Spike sorting the 'Do It Yourself' way

1. Introduction

1.1. An 'elementary' approach

What follows is a detailed spike sorting illustration. It is 'elementary' in the sense that no pre-cooked spike sorting package is used. Some simple functions, detailed at the end of the document, are used, but most of the analysis is carried out with basic functions and methods implemented in Python (Python 3) and its basic scientific modules: numpy and scipy, as well as h5py (to deal with HDF5 files in Python). This is a deliberate choice: Python and, more to the point, its scientific modules are very unstable. They evolve quickly with little concern for backward compatibility, meaning that code relying on 'high level' modules has to be updated every 6 months or so (or run inside containers like Docker or Singularity) to keep working. One way of avoiding, or at least decreasing, these 'maintenance costs' is to limit dependencies, since the core modules are more stable than the high level ones. This is what has been done here.

1.2. Bird's-eye view

Spike sorting as described in this document consists of a succession of 'simple' tasks illustrated by Fig. 1.


Figure 1: Spike sorting as a succession of tasks: A, event detection (probable action potentials); B, cuts / windows, one on each site, of 'well-chosen' length (here 45 sampling points) around the detected extrema; this collection of four cuts (as we have four sites here) constitutes an event (our event space is here \(\mathbb{R}^{180}\) as we have 4×45 amplitudes per event); C, the first 200 detected events aligned on their valley (in red, events made of superpositions); D, dimension reduction, here the projection of the sample on a plane of the subspace spanned by the first three principal components; E, clustering with the k-means method and 10 centers; F, the motifs (centers of the clusters defined in the previous step) corresponding to the 5 'largest' neurons (the 5 different colors) on each of the 4 sites; G1, return to the raw data (black trace) and attribution of a motif to each local extremum, generating a prediction (red trace); G2, in black, the difference between the black and red traces of G1; a detection of local extrema is performed again and the closest motif is assigned to each extremum, giving rise to a new prediction (red trace); G3, the difference between the black and red traces of G2; we continue this 'peeling' procedure until nothing identifiable as any of the motifs of the collection remains.

A brief outline (referring to Fig. 1):

  • A, 100 ms of data recorded at the 4 sites of a tetrode. The data were filtered between 300 Hz and 5 kHz before being digitized with a sampling frequency of 15 kHz (more precisely 14881 Hz). They were then normalized by dividing the signal amplitude at each site by a robust estimator of the noise standard deviation \(\sigma_{\text{noise}}\), the median absolute deviation (MAD).
  • B, after detection of the candidate action potentials as sufficiently large local extrema in absolute value, cuts are made on each of the four sites. The length of the cuts must be judiciously chosen: it is here 3 ms (i.e. 45 sampling points), with 1 ms before the extremum and 2 ms after. This group of four cuts determines an event in the event space, here \(\mathbb{R}^{180}\) since we have 4×45=180 amplitudes per event (the recording conditions are detailed in Pouzat et al. (2002)).
  • C, the first 200 events at each of the four sites, aligned on their extremum; the events in red show overlaps (two or more neurons emitted an action potential during the same time window).
  • D, projection of the events on a plane of the subspace spanned by the first three principal components.
  • E, the same projection with a coloring of the five 'largest neurons' (in the sense of the L1 norm of their motif) obtained by the k-means method with 10 centers.
  • F, motifs of each of the 5 largest neurons (the 5 colors from left to right) at each of the four sites.
  • G1, back to the raw data (black trace) and attribution of a motif to each local extremum generating a prediction (red trace).
  • G2, in black the difference between the black and red traces of G1, a detection of local extrema is performed again and the closest motif is assigned to each extremum, giving rise to a new prediction (red trace).
  • G3, the difference between the black and red traces of G2; we continue this 'peeling' procedure until nothing identifiable as any of the motifs of the collection remains.

1.3. Why go through that if you're not a sorting practitioner?

This document should (hopefully) be of interest to a wider audience than spike sorting practitioners. Indeed, the task requires many statistical and applied maths tools and concepts that are useful in a much wider context, like:

  • data quantization and sampling
  • summary statistics
  • robust statistics
  • dimension reduction
  • clustering
  • approximation via Taylor expansions
  • linear and nonlinear regressions
  • graphical exploration
  • classification

1.4. What is needed?

To re-run the following analysis you will need:

  • Python 3
  • numpy
  • scipy
  • matplotlib
  • h5py
  • GGobi (to get the dynamic displays).

The source files of this document are available on GitLab in the spike-sorting-the-diy-way repository. A PDF version of this document is available.

2. Importing the required modules and loading the data

2.1. Downloading Python code

The individual functions developed for this kind of analysis are defined at the end of this document (Sec. 13) and it is assumed that they are collected in a file named sorting_with_python.py located in your Python working directory. A copy of this file can be downloaded as follows:

import urllib.request as url  #import the module
addr = 'https://gitlab.com/c_pouzat/spike-sorting-the-diy-way/-/raw/master/org/code/sorting_with_python.py'
url.urlretrieve(addr,'sorting_with_python.py')

This custom module is then imported in Python with:

import sorting_with_python as swp

2.2. Importing the 'usual' modules

We are going to use numpy and matplotlib's pyplot interface (we will also use pandas later on, but only to generate one figure, so you can do the analysis without it). We are also going to use matplotlib's interactive mode:

import numpy as np
import matplotlib.pyplot as plt
plt.ion()

2.3. Downloading and reading the data: dealing with HDF5 files

We are going to use a small extract of a dataset recorded from the antennal lobe (the insect's first olfactory relay) of the locust, Schistocerca americana. The original dataset is available on Zenodo in file locust20010214_part1.hdf5, in the group (that's the HDF5 terminology) /Spontaneous_2/trial_30. The extract is contained in file Locust4demo.h5 in our GitLab repository, spike-sorting-the-diy-way. We download the data with:

addr = 'https://gitlab.com/c_pouzat/spike-sorting-the-diy-way/-/raw/master/org/data/Locust4demo.h5'
url.urlretrieve(addr,'Locust4demo.h5')

As already mentioned the data are in HDF5 format, a modern, widely used format designed to store 'complex' datasets together with metadata in a hierarchical way. One can therefore have groups of datasets related in some way. Think of the file / folder structure on your computer: datasets are files and groups are folders / directories. Like the files / folders structure on your computer, groups can be contained in groups, etc.

When the data were recorded, our custom-developed data acquisition software stored the recordings from each channel in an individual file (as 32-bit integers), so 16 recording channels gave rise to 16 different files. When the data were put in HDF5 format, we chose to make each continuous acquisition epoch a group (we were typically recording 30 s to 1 minute of spontaneous activity, followed by 25 times 30 s while applying a specific odor, etc.). The example file we are going to use, Locust4demo.h5, contains a single group; it is therefore 'small' and faster to download. We are going to use the dedicated module h5py to open and 'manipulate' the data file. We start by importing the module, opening the file and getting a glimpse at its content with:

# Import h5py module
import h5py
# Open file Locust4demo.h5
locust = h5py.File('Locust4demo.h5','r')
# Glimpse at content
locust.visititems(print)
spont <HDF5 group "/spont" (4 members)>
spont/site_1 <HDF5 dataset "site_1": shape (431548,), type "<i4">
spont/site_2 <HDF5 dataset "site_2": shape (431548,), type "<i4">
spont/site_3 <HDF5 dataset "site_3": shape (431548,), type "<i4">
spont/site_4 <HDF5 dataset "site_4": shape (431548,), type "<i4">

The file contains a 'top' group named 'spont' (this is a recording of spontaneous activity). This group contains 4 datasets: 'site_1', 'site_2', 'site_3', 'site_4'; all with the same 'shape' (one dimensional arrays, or vectors, with 431548 elements), containing integers coded with 4 bytes (32 bits). We can check if there are metadata associated with this group with:

list(locust['spont'].attrs.keys())
['README']

Let us see the content of this 'README':

locust['spont'].attrs['README']
('This is an extract from file locust20010214_part1.hdf5 from zenodo '
 'repository: https://zenodo.org/record/21589#.XniDEu4o8QU. The data are from '
 'the group: /Spontaneous_2/trial_30 in the original file. These data were '
 'recorded from the locust, Schistocerca americana, antennal lobe by C Pouzat '
 'and O Mazor on February 2 2001. The high-pass filter was set at 300 Hz and '
 'the low pass filter was set at 5 kHz. The data were sampled at 15 kHz.')

Next, we create a copy of the data, stored in a list named data (each list element will be a 1-dimensional array), since we are going to modify (normalize, etc.) the original data. Once we are done, we close the data file:

# Create a list with dataset names
dset_names = ['spont/site_' + str(i) for i in range(1,5)]
# Load the data in a list of numpy arrays
data = [locust[n][...] for n in dset_names]
# Close data file
locust.close()

3. Preliminary analysis

We are going to start our analysis with some "sanity checks" to make sure that nothing "weird" happened during the recording.

3.1. Five number summary

We should start by getting an overall picture of the data, like the one provided by the mquantiles function of the scipy.stats.mstats module, which we use here to output a five-number summary. The five numbers are the minimum, the first quartile, the median, the third quartile and the maximum:

from scipy.stats.mstats import mquantiles
np.set_printoptions(precision=3)
[mquantiles(x,prob=[0,0.25,0.5,0.75,1]) for x in data]
[array([  98., 2011., 2058., 2105., 3634.]),
 array([ 610., 2015., 2057., 2099., 3076.]),
 array([ 762., 2017., 2057., 2095., 2887.]),
 array([ 475., 2014., 2059., 2102., 3143.])]

In the above result, each row corresponds to a recording channel; the first column contains the minimal value, the second the first quartile, the third the median, the fourth the third quartile and the fifth the maximal value.

3.2. Quantized data

Like almost all (experimental) data that you are likely to encounter, these data have been digitized, i.e. a continuous quantity, here a potential, has been discretized or quantized. The continuous potentials have thus been stored as discrete values by means of an analog-to-digital converter. Concretely, in our case, the extracellular potentials, after amplification, were put in the closest class, where the classes are obtained by dividing the interval [-5V,+5V] into \(2^{12}\) classes (the acquisition card having a 'precision' of 12 bits, as specified in the article that first used these data). The data stored in the files we downloaded, which are now in the data list, actually contain the class indices. We could convert these quantities into volts by multiplying them by \(10/2^{12}\) and then subtracting 5. Clearly, this quantization implies a loss of precision, since all the values between the two bounds delimiting a class are by definition in this class; we speak in this context of quantization noise. To obtain the best possible precision, experimenters try to use acquisition cards with the greatest number of bits (16 rather than 12, etc.) and adjust their amplification so that their 'peak-to-peak' signal (i.e., the domain covered by the amplitude) is as close as possible to the limits of their acquisition card, in this case [-5V,+5V]. This can sometimes lead them to over-amplify, which results in a truncated signal: all potentials < -5V will be in class 0 and all potentials > +5V will be in class \(2^{12}-1 = 4095\).
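As a minimal sketch of the conversion just described, assuming (as stated above) a 12-bit converter covering [-5V,+5V], the class indices can be turned into volts at the converter input (i.e. after amplification) and the quantization step computed. The helper name class_to_volts is hypothetical.

```python
import numpy as np

# Assumptions taken from the text: 12-bit converter, range [-5 V, +5 V]
N_BITS = 12
V_MIN, V_MAX = -5.0, 5.0
N_CLASSES = 2 ** N_BITS  # 4096 classes, indices 0 to 4095

def class_to_volts(idx):
    """Convert class indices to volts: idx * 10 / 2**12 - 5 (hypothetical helper)."""
    return np.asarray(idx, dtype=float) * (V_MAX - V_MIN) / N_CLASSES + V_MIN

# Width of one class, i.e. the quantization step, in volts
step = (V_MAX - V_MIN) / N_CLASSES
print(step)  # 0.00244140625, about 2.4 mV

# Site 1 five-number summary (class indices) converted to volts
print(class_to_volts([98, 2011, 2058, 2105, 3634]))
```

The step of about 2.4 mV is the size of the quantization noise at the converter input; after division by the amplifier gain it becomes much smaller at the electrode.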

3.3. How to read the 5 number summary?

Here we are dealing with data that have been filtered with a 300 Hz high-pass filter, i.e. all signal components with a frequency below 300 Hz have been eliminated. We therefore expect the signal to be centered on 0 V, i.e. in the class \((2^{12}-1)/2=2047\) (using integer division), and we see that the median (between 2057 and 2059 on the four sites) is indeed very close to this value. We also see that the maximum value, 3634, is lower than but 'close' to the maximum possible value, 4095 (less than a factor of 2 difference). Similarly, the minimum value, 98, is larger than but still close to the minimum possible value, 0. This means that the gain of the amplifier has been correctly adjusted to the specifications of the acquisition card (here I shamelessly congratulate myself!) since the full range of coding bits is used, the highest one, \(2^{11}\), included: \(3634=2^{11}+2^{10}+2^9+2^5+2^4+2\).
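The bit decomposition can be checked with a few lines of plain Python, using the site 1 extremes of the five-number summary:

```python
# Site 1 extremes from the five-number summary (class indices)
min_val, max_val = 98, 3634

print(bin(max_val))  # 0b111000110010
# Powers of two whose sum gives max_val
powers = [k for k in range(12) if (max_val >> k) & 1]
print(powers)        # [1, 4, 5, 9, 10, 11]
# Number of bits needed to write each value (at most 12 here)
print(min_val.bit_length(), max_val.bit_length())  # 7 12
# Fraction of the available range (0 to 4095) spanned by the data
print(round((max_val - min_val) / 4095, 3))
```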

Another potential 'pathology' results from a saturation of the amplifier used before the quantization stage. You are all familiar with it: when a person speaks into a microphone while standing too close to the loudspeaker broadcasting his or her voice, the positive feedback amplifies the signal until the amplifier saturates and produces a very unpleasant howling sound that continues even after the speaker has covered the microphone. In other words, when saturation occurs, it usually lasts a 'long time' (this is due to the electronics of the amplifier) and, on the 5-number summary, it will often appear as a first quartile very close to the minimum and a third quartile very close to the maximum. This is not the case here, so everything is fine so far.

We also pay attention to possible non-uniformities of the statistics between rows (recording sites); for close recording sites we expect distributions that are not identical, but similar. In particular, the interquartile range (third quartile minus first quartile) provides a robust estimate of the signal dispersion and should not be very different from one row to another.
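The checks of this section can be sketched with numpy on the five-number summaries printed above. The 'gap' ratio used below as a saturation indicator is an ad hoc choice made for this illustration, not a standard test: values close to 0 would mean that a quartile sits next to an extreme.

```python
import numpy as np

# Five-number summaries printed above (one row per recording site):
# minimum, first quartile, median, third quartile, maximum
summary = np.array([[  98., 2011., 2058., 2105., 3634.],
                    [ 610., 2015., 2057., 2099., 3076.],
                    [ 762., 2017., 2057., 2095., 2887.],
                    [ 475., 2014., 2059., 2102., 3143.]])

# Interquartile range of each site: a robust measure of dispersion
iqr = summary[:, 3] - summary[:, 1]
print(iqr)  # [94. 84. 78. 88.]

# Ad hoc saturation indicator: distance between quartiles and extremes,
# relative to the full observed range (near 0 would suggest saturation)
full_range = summary[:, 4] - summary[:, 0]
lower_gap = (summary[:, 1] - summary[:, 0]) / full_range
upper_gap = (summary[:, 4] - summary[:, 3]) / full_range
print(np.round(lower_gap, 2), np.round(upper_gap, 2))
```

The four interquartile ranges are similar, and all gaps are far from 0, in agreement with the conclusions drawn above.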

3.4. Robust statistics

Why use the median rather than the mean as an estimator of the location, or 'typical value', of a random variable? Let's imagine that we observe a sequence of draws following a uniform distribution between -0.5 and +0.5 (we don't know that, which is precisely why we make our observations) and that, out of 100 observations, one is 'corrupted', because it was, for example, wrongly written down and appears in our data table with the value 100. If we calculate the empirical mean of such a sample, we obtain a random variable whose expectation is 1 (do the calculation yourself); if we calculate the median, we obtain another random variable whose expectation is very close to 0 (0.005 to be exact), the expectation of the uncorrupted data. It is in this sense that the median is a robust estimator of location and the mean is not. You can convince yourself with a small simulation (without forgetting to set the seed of your pseudorandom number generator):

import statistics
import random
random.seed(20061001) # seed the generator
samp = [random.random()-0.5 for i in range(99)]
samp += [100]
[round(statistics.mean(samp),3), round(statistics.median(samp),3)]

If we compute the standard deviation of our sample, we get a new random variable whose expectation is close to 10 (again, do the maths), while the standard deviation of the uncorrupted distribution is \(1/\sqrt{12}\). We would therefore like an estimator that is to the standard deviation what the median is to the mean when we estimate the dispersion (or scale) of a sample. Several estimators meet this requirement and we will use the simplest one, the median absolute deviation (MAD), which is, up to a factor, the median of the absolute deviations from the median (the factor, 1.4826, guarantees that, for an IID Gaussian sample, the estimator converges to the standard deviation of the underlying distribution as the sample size grows):

import math
samp_sd = statistics.stdev(samp)
samp_median = statistics.median(samp)
samp_mad = 1.4826*statistics.median([abs(obs-samp_median) for obs in samp])
[round(samp_sd,3), round(samp_mad,3), round(1/math.sqrt(12),3)]

The conclusion of this is that when you are analyzing real data and have the slightest suspicion that some observations may have been corrupted, always use the median as the location estimator and the MAD as the dispersion estimator.
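The value 1.4826 comes from the Gaussian case: for a normal variable with standard deviation \(\sigma\), the median of the absolute deviations from the median is \(\sigma\,\Phi^{-1}(3/4)\), where \(\Phi\) is the standard normal cumulative distribution function, so the factor is \(1/\Phi^{-1}(3/4)\). The following sketch, using scipy.stats.norm, recovers it and checks the behavior on a large simulated Gaussian sample; the mad function defined here is an illustration and not necessarily identical to swp.mad.

```python
import numpy as np
from scipy.stats import norm

# For a Gaussian with SD sigma, median(|X - median(X)|) = sigma * Phi^{-1}(3/4)
factor = 1 / norm.ppf(0.75)
print(round(factor, 4))  # 1.4826

def mad(x):
    """Scaled median absolute deviation (illustration, not necessarily swp.mad)."""
    x = np.asarray(x, dtype=float)
    return factor * np.median(np.abs(x - np.median(x)))

# Large simulated Gaussian sample with standard deviation 2
rng = np.random.default_rng(20061001)
gauss = rng.normal(loc=0.0, scale=2.0, size=100_000)
print(round(mad(gauss), 2), round(float(np.std(gauss)), 2))  # both close to 2
```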

4. Look at your data!

Plotting the data for interactive exploration is trivial. The only trick is to add (or subtract) a proper offset (that we get here using the maximal value of each channel from our five-number summary); this is automatically implemented in our plot_data_list function:

tt = np.arange(0,len(data[0]))/1.5e4
swp.plot_data_list(data,tt,0.1)

The first channel is drawn as is, the second is offset downward by the sum of its maximal value and of the absolute value of the minimal value of the first, etc. We then get something like Fig. 2.


Figure 2: The whole (about 29 s) Locust antennal lobe data set.
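The offset rule described above can be sketched as follows; stacking_offsets is a hypothetical helper written for illustration, not the code of plot_data_list. For traces centered on 0, each shifted trace's maximum then sits at the level of the previous trace's minimum.

```python
import numpy as np

def stacking_offsets(traces):
    """Hypothetical helper: vertical offsets used to stack traces.

    The first trace is not shifted; trace i is shifted down, relative to
    trace i-1, by the maximum of trace i plus the absolute value of the
    minimum of trace i-1.
    """
    offsets = [0.0]
    for prev, cur in zip(traces[:-1], traces[1:]):
        offsets.append(offsets[-1] - (np.max(cur) + abs(np.min(prev))))
    return np.array(offsets)

# Toy traces centered on 0
toy = [np.array([-2.0, 1.0, 3.0]),
       np.array([-1.0, 0.5, 2.0]),
       np.array([-4.0, 0.0, 1.0])]
off = stacking_offsets(toy)
print(off)  # offsets 0, -4 and -6
# Each trace would then be drawn with plt.plot(t, trace + offset)
```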

It is also good to "zoom in" and look at the data with a finer time scale (Fig. 3) with:

plt.xlim([0,0.2])


Figure 3: First 200 ms of the Locust data set.

5. Data renormalization

We are going to use a median absolute deviation (MAD) based renormalization. The goal of the procedure is to scale the raw data such that the noise SD is approximately 1. Since it is not straightforward to obtain a noise SD on data where both signal (i.e., spikes) and noise are present, we use this robust type of statistic for the SD:

data_mad = list(map(swp.mad,data))
data_mad
[69.6822, 62.2692, 57.8214, 65.2344]

And we normalize accordingly (we also subtract the median which is not 0!):

data = list(map(lambda x: (x-np.median(x))/swp.mad(x), data))

We can check on a plot (Fig. 4) how MAD and SD compare:

plt.plot(tt,data[0],color="black",lw=0.5)
plt.xlim([0,0.2])
plt.axhline(y=1,color="red")
plt.axhline(y=-1,color="red")
plt.axhline(y=np.std(data[0]),color="blue",linestyle="dashed")
plt.axhline(y=-np.std(data[0]),color="blue",linestyle="dashed")
plt.xlabel('Time (s)')
plt.ylim([-20,15])


Figure 4: First 200 ms on site 1 of the Locust data set. In red: +/- the MAD; in dashed blue +/- the SD.

5.1. A quick check that the MAD "does its job"

We can check that the MAD does its job as a robust estimate of the noise standard deviation by looking at Q-Q plots of the whole traces normalized with the MAD and normalized with the "classical" SD (Fig. 5):

dataQ = map(lambda x:
            mquantiles(x, prob=np.arange(0.01,0.99,0.001)),data)
dataQsd = map(lambda x:
              mquantiles(x/np.std(x), prob=np.arange(0.01,0.99,0.001)),
              data)
from scipy.stats import norm
qq = norm.ppf(np.arange(0.01,0.99,0.001))
plt.plot(np.linspace(-3,3,num=100),np.linspace(-3,3,num=100),
         color='grey')
colors = ['black', 'orange', 'blue', 'red']
for i,y in enumerate(dataQ):
    plt.plot(qq,y,color=colors[i])

for i,y in enumerate(dataQsd):
    plt.plot(qq,y,color=colors[i],linestyle="dashed")

plt.xlabel('Normal quantiles')
plt.ylabel('Empirical quantiles')


Figure 5: Performance of MAD-based vs SD-based normalizations. After normalizing the data of each recording site by its MAD (plain colored curves) or its SD (dashed colored curves), Q-Q plots against a standard normal distribution were constructed. Colors: site 1, black; site 2, orange; site 3, blue; site 4, red.

If you want to save the figure, use matplotlib's savefig function, e.g. plt.savefig('check-MAD.png').

We see that the behavior of the "away from normal" fraction is much more homogeneous, for small as well as for large quantile values, with the MAD normalized traces than with the SD normalized ones. Since the SD is inflated by the spikes, a threshold set at three SDs lies further from zero than one set at three MADs; with an automatic rule like the 'three sigmas' we would therefore detect fewer events (i.e., get fewer putative spikes) with the SD based normalization than with the MAD based one.
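A small simulation illustrates why: the sketch below adds sparse negative 'spikes' of two made-up amplitudes to unit-variance Gaussian noise, then compares the SD and the MAD and counts the samples below minus three times each of them. The amplitudes and proportions are arbitrary assumptions, not estimates from the locust data.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100_000
# Simulated trace: unit-SD Gaussian noise...
trace = rng.normal(size=n)
# ...plus 1000 large (-15) and 1000 small (-5) made-up negative "spikes"
idx = rng.choice(n, size=2_000, replace=False)
trace[idx[:1_000]] -= 15.0
trace[idx[1_000:]] -= 5.0

sd = np.std(trace)
mad = 1.4826 * np.median(np.abs(trace - np.median(trace)))
print(round(sd, 2), round(mad, 2))  # the SD is inflated by the spikes

# Samples below a "three sigmas" threshold with each scale estimate
n_sd = int(np.sum(trace < -3 * sd))
n_mad = int(np.sum(trace < -3 * mad))
print(n_sd, n_mad)  # fewer candidates with the SD based threshold
```

With these settings most of the small spikes fall above the SD based threshold but below the MAD based one.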

6. Detect spike candidates

6.1. What to detect?

The characteristic attribute of the data we have just visualized is the presence of 'spikes': brief and 'large amplitude' deviations of the extracellular potential. We expect action potentials to generate such deviations (for a detailed justification of this not necessarily obvious statement, see chapter Multi-Unit Recording: Fundamental Concepts and New Directions) and we further expect these deviations to be larger on the negative side than on the positive side. This leads us to look for 'valleys' rather than 'peaks'. But we must not ignore the fact that any recording is subject to 'noise' which, in our case, has biological, external (the laboratory is bathed in electromagnetic waves, some of which have frequencies close to those of our signal, and we cannot always completely isolate our experimental rig from them) and electronic (amplifier) origins. We will therefore look for valleys, but we are not interested in detecting all of them; we will focus only on those of 'sufficiently' large amplitude. Here, there is a choice to be made; a choice that is partly arbitrary and that must therefore be documented. We will look for local minima whose value is lower than a threshold chosen as a multiple of our noise level. This explains why we were concerned, in the previous section, with obtaining as objective and reliable an estimate of the noise level as possible. The first decision to make is therefore the detection threshold. Figure 4 also shows quite clearly that a fairly high frequency noise is present; this can cause problems when we are looking for the position of an extremum of an underlying lower frequency signal: near the extremum, the signal is locally flat and the 'fast fluctuating' noise superimposed on it can lead us to miss the true extremum by choosing a neighboring point which, by chance, is more extreme.
We will mitigate the effect of high frequency noise by performing our detection on a filtered version of the data: we will replace each observed value by the average of itself and some of its closest neighbors, which is called a box filter. We will then have to make a second decision, that of the filter length (how many points should we use on each side).
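A minimal sketch of the box filter, using the same scipy.signal.fftconvolve call as in Sec. 6.3 and a length of 5 (the value chosen there): away from the edges each output sample is the mean of 5 consecutive inputs and, on white noise, averaging 5 independent values divides the standard deviation by \(\sqrt{5}\).

```python
import numpy as np
from scipy.signal import fftconvolve

k = 5                       # filter length: the point and 2 neighbors on each side
box = np.ones(k) / k        # box (moving average) kernel

rng = np.random.default_rng(0)
noise = rng.normal(size=100_000)               # white, unit-SD noise
smoothed = fftconvolve(noise, box, mode='same')

# Away from the edges, each output sample is the mean of 5 neighbors
i = 1000
print(np.isclose(smoothed[i], noise[i - 2:i + 3].mean()))  # True
# Averaging 5 independent values divides the SD by sqrt(5)
print(round(float(np.std(smoothed)), 3), round(1 / np.sqrt(k), 3))
```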

You should not conclude from this that you should always filter the data and, even less, that you should always filter it with a box filter. We are implementing a practical solution whose ultimate justification is not a nice theorem, but the fact that 'it works'. This means that you are free to explore other solutions, such as other types of filters; and you can do this because the analysis is done in a general purpose environment like Python, which imposes very few constraints on you (in this respect R or Matlab would also be very suitable). If you were to use 'push-button' software, as most practitioners of spike sorting do, you would be completely constrained by the choices of its designer(s).

Since you 'can do whatever you want' because you are using Python, you could also try to work on the derivative of your data rather than on the raw data (you can get the derivative numerically, as we will do later, by replacing each observation by the difference between its right and left neighbors divided by 2); detecting peaks of the derivative may then be more efficient than detecting valleys of the data. Keep in mind that there is no ideal solution that works in all cases. To obtain a good quality sorting you should, in my opinion, use software that does not restrict you and make full use of the freedom that your tool gives you.
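The numerical derivative mentioned above can be sketched as follows; central_derivative is a hypothetical name and the treatment of the first and last samples (set to 0) is an arbitrary choice. numpy's np.gradient uses the same central difference for interior points.

```python
import numpy as np

def central_derivative(x):
    """Hypothetical helper: derivative estimate (x[i+1] - x[i-1]) / 2.

    Units are 'amplitude per sampling point'. The first and last samples
    lack a neighbor; they are set to 0 here (an arbitrary choice).
    """
    x = np.asarray(x, dtype=float)
    d = np.zeros_like(x)
    d[1:-1] = (x[2:] - x[:-2]) / 2
    return d

x = np.array([0.0, 1.0, 4.0, 9.0, 16.0])  # t**2 sampled at t = 0, 1, 2, 3, 4
print(central_derivative(x)[1:-1])        # 2, 4, 6: exact for a parabola (2t)
# np.gradient uses the same central difference for interior points
print(np.gradient(x)[1:-1])
```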

6.2. Where to detect?

We have one more issue to discuss before we get into the actual detection. Figure 3 (the first 200 ms of data) clearly shows, fortunately, that most of the (candidate) action potentials are visible on several sites (on all four sites for the first event, for example). If we performed a detection on each of the 4 sites separately, we might get slightly different extremum times due to the high frequency noise mentioned above (even after filtering), which would require choosing an event time in a second step.

What we are going to do here (once again, this is a practical solution that can be criticized and improved) is first to 'truncate' our signal at each site by forcing to 0 any value above our detection threshold (we are looking for valleys), and then to sum, at each time, the truncated traces of all sites, so as to have only one trace. The valleys that we then observe will automatically be a 'compromise' between what happens on each of the sites, with a stronger weight given to the site on which the event is the largest.
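On made-up numbers, the truncate-and-sum procedure looks like this; note how the valley of the sum falls at the time of the deepest valley (site 1) rather than at that of site 2:

```python
import numpy as np

threshold = -4.0
# Made-up normalized traces: 2 sites (rows) x 6 time steps
sites = np.array([[0.5, -3.0, -6.0, -9.0, -5.0, 1.0],
                  [0.0, -4.5, -7.0, -6.0, -2.0, 0.5]])

truncated = sites.copy()
truncated[truncated > threshold] = 0   # keep only values below the threshold
combined = truncated.sum(axis=0)       # a single trace for all sites
print(combined)   # 0, -4.5, -13, -15, -5, 0
print(np.argmin(sites, axis=1), np.argmin(combined))  # [3 2] 3
```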

6.3. How to proceed?

We are going to filter the data slightly using a 'box' filter of length 5. That is, the data points of the original trace are going to be replaced by the average of themselves and their four nearest neighbors (two on each side). We will then scale the filtered traces such that the MAD is one on each recording site and keep only the parts of the signal below our threshold, set at -4:

from scipy.signal import fftconvolve
from numpy import apply_along_axis as apply
data_filtered = apply(lambda x:
                      fftconvolve(x,np.array([1,1,1,1,1])/5.,'same'),
                      1,np.array(data))
dfiltered_mad_original = apply(swp.mad,1,data_filtered)
data_filtered = (data_filtered.T / \
                 dfiltered_mad_original).T
data_filtered_full = data_filtered.copy()
data_filtered[data_filtered > -4] = 0

Notice that at the normalization stage (division by the MAD), transposition is used. For a detailed explanation of this subtle point, check the numpy documentation on broadcasting. We can compare the raw trace with the filtered and truncated one on which spikes are going to be detected (Fig. 6) with:

plt.plot(tt, data[0],color='black',lw=0.5)
plt.axhline(y=-4,color="blue",linestyle="dashed")
plt.plot(tt, data_filtered[0,],color='red',lw=0.5)
plt.xlim([0,0.2])
plt.ylim([-20,15])
plt.xlabel('Time (s)')


Figure 6: First 200 ms on site 1 of the data set. The raw data are shown in black, the detection threshold appears in dashed blue and the filtered and truncated trace on which spike detection is going to be performed appears in red.
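The transposition trick used at the normalization stage works because numpy broadcasting aligns trailing axes: data_filtered.T has shape (n_samples, 4) and the vector of MADs has shape (4,), so the division is done site by site. The sketch below, on simulated data, checks that adding an explicit axis with np.newaxis gives the same result:

```python
import numpy as np

rng = np.random.default_rng(1)
# Simulated 4-site recording: each row has its own scale (1 to 4)
x = rng.normal(size=(4, 1_000)) * np.array([[1.0], [2.0], [3.0], [4.0]])
scales = x.std(axis=1)                  # one value per site, shape (4,)

# Transposition trick: (1000, 4) / (4,) divides column j by scales[j]
a = (x.T / scales).T
# Equivalent: give scales an explicit trailing axis, shape (4, 1)
b = x / scales[:, np.newaxis]
print(np.allclose(a, b))                # True
print(np.round(a.std(axis=1), 6))       # each site now has unit SD
# x / scales would raise a ValueError: shapes (4, 1000) and (4,) do not broadcast
```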

In order to have a 'single time per event' when several recording sites are used simultaneously, we sum these filtered and rectified traces from our four channels and we detect on that sum (bottom trace of Fig. 7):

fig, axs = plt.subplots(nrows=5, sharex=True)
for i in range(4):
    ax = axs[i]
    ax.plot(tt, data_filtered_full[i],color='black',lw=0.5)
    ax.axhline(y=-4,color="blue",linestyle="dashed")
    ax.plot(tt, data_filtered[i],color='red',lw=0.5)

# sum across channels
data_sum = np.sum(data_filtered, axis=0)

ax = axs[4]
ax.plot(tt, data_sum, color='black')
ax.set_xlim([0.024,0.090])
ax.set_ylim([-100,10])
ax.set_xlabel('Time (s)')


Figure 7: Sites 1 to 4 (top four panels) between 24 and 90 ms: the filtered data are shown in black, the detection threshold appears in dashed blue and the filtered and truncated traces appear in red. The bottom panel shows the sum of the four truncated traces, on which spike detection is going to be performed.

We now use the peak function on the sum of the rows of our filtered and truncated version of the data (since peak looks for peaks and we want to detect valleys, we use the opposite of data_filtered):

sp0 = swp.peak(-data_filtered.sum(0))

This gives 1950 spikes, with a mean inter-event interval of 221 sampling points, a standard deviation of 204 sampling points, a smallest inter-event interval of 16 sampling points and a largest one of 2904 sampling points.
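The peak function is defined in Sec. 13 and is not reproduced here. As a rough stand-in, the sketch below defines peak_sketch, a hypothetical local-maximum detector whose rule may differ from that of swp.peak, together with interval_summary, a hypothetical helper computing inter-event interval statistics like those quoted above:

```python
import numpy as np

def peak_sketch(x, minimal_value=0.0):
    """Hypothetical stand-in for swp.peak: indices of local maxima of x.

    An interior sample is kept when it is larger than minimal_value,
    larger than its left neighbor and not smaller than its right one.
    """
    x = np.asarray(x, dtype=float)
    mid = x[1:-1]
    is_peak = (mid > x[:-2]) & (mid >= x[2:]) & (mid > minimal_value)
    return np.nonzero(is_peak)[0] + 1

def interval_summary(times):
    """Number of events and inter-event interval statistics (sampling points)."""
    iei = np.diff(np.asarray(times))
    return {"n": len(times), "mean": float(iei.mean()), "sd": float(iei.std()),
            "min": int(iei.min()), "max": int(iei.max())}

# Made-up summed and truncated trace: valleys become peaks of its opposite
summed = np.array([0.0, -4.5, -13.0, -15.0, -5.0, 0.0, 0.0, -6.0, -4.2, 0.0])
print(peak_sketch(-summed))                          # [3 7]
print(interval_summary([100, 316, 540, 556, 900]))   # intervals 216, 224, 16, 344
```

With the real detection, interval_summary(sp0) would compute statistics of the kind quoted above; whether the document's standard deviation uses ddof=0 (as here) or ddof=1 is not stated.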

6.4. Interactive spike detection check

We can then check the detection quality with (Fig. 8):

swp.plot_data_list_and_detection(data,tt,sp0)
plt.xlim([0,0.125])


Figure 8: First 125 ms of the data set. The raw data are shown in black, the detected events are signaled by red dots (a dot is put on each recording site at the amplitude on that site at that time).

7. Cuts

7.1. Key data generation hypothesis

If we set out to do a task like sorting action potentials, we have good reason to believe that there are multiple neurons active in our recording; otherwise our analysis would be complete as soon as the events were detected. If there are several neurons in our recording, we expect to observe several different forms of action potentials in our data. To move forward, we will start with a simple (if not simplistic) data generation model:

  1. Each neuron (visible in our recordings) generates a 'signal', i.e. a given shape or function of time, always the same, on each of the recording sites. To remove any ambiguity: the shape associated with a given neuron is generally not the same on each site (there are as many shapes as there are recording sites for a given neuron), but each time this neuron generates an action potential, the same collection of (noisy) shapes is observed. We will refer to the collection of shapes associated with a neuron as a motif or reference waveform.
  2. The neurons discharge (or generate action potentials) according to point processes that are 'not too correlated' (the estimation of the intensity of these processes is in fact what occupies us most often once the sorting of the action potentials is done).
  3. Each neuron generates a 'signal' on each of the recording sites, which is a function of time obtained by convolution of the realization of its point process with its motif on the site in question.
  4. The signals generated by the different neurons are added together on each site; the observed data are a noisy version of this sum of signals (the noise is never white since the data have been filtered before being digitized).

This data generation model was first spelled out a long time ago by Bill Roberts, in a paper entitled Optimal recognition of neuronal waveforms (1979).
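The four hypotheses translate into a short simulation. The sketch below generates two sites and two neurons with made-up motifs, uses a Bernoulli approximation of a Poisson process for the spike trains and, for simplicity, white Gaussian noise (whereas, as stated in item 4, real recording noise is not white):

```python
import numpy as np

rng = np.random.default_rng(2023)
n_samples, n_sites = 3_000, 2

# Made-up common waveform (30 points): a valley followed by a small bump
t = np.arange(30)
shape = (-np.exp(-0.5 * ((t - 10) / 2.0) ** 2)
         + 0.3 * np.exp(-0.5 * ((t - 18) / 4.0) ** 2))
# Hypothesis 1: one motif per neuron = one waveform per site
motifs = [np.stack([8 * shape, 3 * shape]),   # neuron 0, larger on site 1
          np.stack([2 * shape, 6 * shape])]   # neuron 1, larger on site 2

def neuron_signal(train, motif_on_site):
    """Hypothesis 3: spike train (0/1 vector) convolved with the motif."""
    return np.convolve(train, motif_on_site)[:len(train)]

sim = np.zeros((n_sites, n_samples))
for motif in motifs:
    # Hypothesis 2: independent point processes (Bernoulli approximation
    # of a Poisson process with 0.005 spike per sampling point)
    train = (rng.random(n_samples) < 0.005).astype(float)
    for s in range(n_sites):
        # Hypothesis 4: the signals of the neurons add up on each site
        sim[s] += neuron_signal(train, motif[s])

# Observed data: noisy sum (white noise here, a simplification)
sim += rng.normal(size=sim.shape)
```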

7.2. A remark

The strong assumption we have just made is that the motif associated with each neuron is constant over time. There are many cases where this assumption is clearly false but, as we will soon see, it holds here.

We could be even more restrictive in our assumptions and consider that a single shape is associated with each neuron and that what we observe on each of our sites when a given neuron emits an action potential is this shape multiplied by a factor specific to the neuron and to the recording site. Instead of having, for each neuron, as many shapes as sites, we would then have one shape and as many multipliers as sites. This assumption is generally valid (see the illustration in Fig. 19 below), but it does not dramatically simplify the analysis that follows, while it makes the code more complicated. We will therefore not make this assumption. However, since the factors are determined by geometry (relative positions of the neuron and the recording sites) and since no two neurons can be at exactly the same position, we expect to have different motifs for our different neurons. This does not imply that we will be able to discriminate these motifs: that depends on the ratio of inter-motif differences to the noise level.
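The stronger hypothesis amounts to saying that a neuron's motif, written as a (sites × time) matrix, is the outer product of a vector of site factors and a single shape, hence a matrix of rank 1. The sketch below, with a made-up shape and made-up factors, shows this with a singular value decomposition; the same check could be applied to estimated motifs.

```python
import numpy as np

t = np.arange(45)
# Made-up common shape and made-up factors for 4 sites
shape = -np.exp(-0.5 * ((t - 15) / 3.0) ** 2)
factors = np.array([1.0, 0.6, 0.3, 0.1])

# Stronger hypothesis: motif (4 sites x 45 points) = factors x shape
motif = np.outer(factors, shape)
# A rank 1 matrix has a single non-zero singular value
sv = np.linalg.svd(motif, compute_uv=False)
print(np.linalg.matrix_rank(motif))   # 1
print(sv[1:] / sv[0])                 # ratios close to 0 (rounding errors)
# 4 x 45 = 180 amplitudes described by 45 + 4 numbers
```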

7.3. Choosing the events' length

Once we are satisfied with our detection, we need to build our events. We will do this by taking a 'window' or a 'cut', i.e. a piece of data, around each detection time (which I will refer to as the reference time in the following), and this on each of the 4 sites. An event will thus consist of 4 windows / cuts, all of the same length and 'synchronous' in the sense that element \(i\) of each cut is a measurement of the extracellular potential at the same (real) time on each of the sites. The question we have to consider is then twofold: what is the 'right' cut length and what is the 'right' position of the reference time within the cut?

We want our cuts to be as short as possible because:

  • it will decrease the memory space occupied by our sample,
  • it will decrease the computation time,
  • it will decrease the number of overlapping events (when two or more neurons emit an action potential in the same window).

7.4. Implementation

We go back to the normalized raw data; the filtered data are used for spike detection only. We can start with very (too) long cuts, align the events and estimate the 'central' event by the pointwise median, computing the pointwise MAD at the same time. Wherever the waveforms of the different neurons differ enough from each other, the variability around the central event, as quantified by the MAD, should be above the level of the recording noise (which our previous normalization procedure set, in principle, to 1). Far from the reference time, only noise remains and the MAD should be close to 1. So we only need to look at how many sampling points to the left (before) and to the right (after) of the reference time are needed for the MAD to return to 1.

To create our collection of cuts (or set of events), we use the mk_events function. It takes a spike sequence (like sp0) as its first parameter, the data matrix as its second, the number of sampling points to keep before the reference time as its third and the number of sampling points to keep after the reference time as its fourth. For each time of the spike sequence, the function cuts the corresponding window out of every site and puts these cuts end to end in a single vector; these vectors are stacked as the rows of a matrix, which is essentially what the function returns. So we start by creating a sample with long windows (49 points before and 50 after the reference time, 100 points in total), then we compute the median event and its MAD:

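# one row per event: the four 100-point cuts put end to end
evts = swp.mk_events(sp0,np.array(data),49,50)
# apply(f, axis, arr) applies f along axis, like numpy's apply_along_axis
evts_median = apply(np.median,0,evts)  # pointwise median: the 'central' event
evts_mad = apply(swp.mad,0,evts)       # pointwise MAD: dispersion around it
plt.axhline(y=0, color='black')
# thick vertical lines at the boundaries between recording sites
for i in np.arange(0,400,100):
    plt.axvline(x=i, color='black', lw=2)

# thin vertical lines every 10 points to read off the window limits
for i in np.arange(0,400,10):
    plt.axvline(x=i, color='grey')

plt.plot(evts_median, color='red', lw=2)
plt.plot(evts_mad, color='blue', lw=2)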

check-MAD-on-long-cuts.png

Figure 9: Robust estimates of the central event (red) and of the sample's dispersion around the central event (blue) obtained with "long" (100 sampling points) cuts. We see clearly that the dispersion is back to noise level 15 points before the peak and 30 points after the peak.

Fig. 9 clearly shows that starting the cuts 15 points before the peak and ending them 30 points after should fulfill our goals. We also see that the central event slightly outlasts the window where the MAD is larger than 1.
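Reading the window limits off Fig. 9 can also be done programmatically. The sketch below uses a hypothetical helper, window_from_mad (not part of swp), and assumes that every cut in the MAD vector has the same length with the reference time at the same position (index 49 here, since mk_events was called with 49 points before). It takes the outermost points where the MAD exceeds a given level, so a single noisy excursion can widen the window; since the MAD of pure noise fluctuates around 1, a level slightly above 1 is probably a safer choice.

```python
import numpy as np

def window_from_mad(mad, ref_idx, n_sites=4, level=1.0):
    """Hypothetical helper: return (n_before, n_after), the number of sampling
    points before / after the reference time over which the pointwise MAD
    exceeds `level` on at least one recording site.

    mad     -- pointwise MAD of events made of n_sites cuts put end to end
    ref_idx -- position of the reference time within each cut
    """
    per_site = np.asarray(mad, dtype=float).reshape(n_sites, -1)  # one row per site
    above = np.flatnonzero(np.any(per_site > level, axis=0))      # indices within a cut
    if above.size == 0:
        return 0, 0
    n_before = max(ref_idx - int(above[0]), 0)
    n_after = max(int(above[-1]) - ref_idx, 0)
    return n_before, n_after

# Toy MAD profile: noise level (0.9) everywhere except from 14 points before
# to 30 points after the reference time (index 49), on every site.
toy_mad = np.full((4, 100), 0.9)
toy_mad[:, 35:80] = 3.0
print(window_from_mad(toy_mad.ravel(), ref_idx=49))   # (14, 30)
```

With the data of this section one would call, for instance, window_from_mad(evts_mad, 49, level=1.1); the value 1.1 is an arbitrary choice, and the result should be checked against a figure like Fig. 9.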

7.5. Events

Once we are satisfied with our spike detection, at least provisionally, and have decided on the length of our cuts, we make cuts around the detected events, keeping 14 points before and 30 points after each reference time (45 points per site, the reference time included):

evts = swp.mk_events(sp0,np.array(data),14,30)
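To make explicit what mk_events returns, here is a minimal sketch of the same cutting operation. mk_events_sketch is hypothetical and written for illustration only; the swp.mk_events defined at the end of the document may differ in details, in particular in how it treats events too close to the edges of the recording (this sketch simply drops them). With 4 sites and 45-point cuts, each row has 180 columns.

```python
import numpy as np

def mk_events_sketch(positions, data, before=14, after=30):
    """Hypothetical sketch: for each reference time p in `positions`, cut
    data[:, p-before : p+after+1] on every site and put the cuts end to end.
    Returns a matrix with one row per event and n_sites*(before+after+1) columns."""
    data = np.asarray(data, dtype=float)
    n_sites, n_samples = data.shape
    length = before + after + 1
    # keep only the events whose window lies entirely within the recording
    kept = [p for p in positions if p - before >= 0 and p + after < n_samples]
    res = np.empty((len(kept), n_sites * length))
    for k, p in enumerate(kept):
        # row-major ravel: site 0's cut, then site 1's cut, and so on
        res[k, :] = data[:, p - before:p + after + 1].ravel()
    return res

toy = np.arange(2 * 20).reshape(2, 20)        # 2 sites, 20 samples each
ev = mk_events_sketch([5, 1, 12], toy, before=2, after=3)
print(ev.shape)                               # (2, 12): position 1 is dropped
print(ev[0])
```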

We can visualize the first 200 events with (Fig. 10):

swp.plot_events(evts,200)

first-200-of-evts.png

Figure 10: First 200 events of evts. Cuts from the four recording sites appear one after the other. The background (white / grey) changes with the site. In red, robust estimate of the "central" event obtained by computing the pointwise median. In blue, robust estimate of the scale (SD) obtained by computing the pointwise MAD.

7.6. Getting "clean" events

Our spike sorting has two main stages: the first one consists in estimating a model and the second one in using this model to classify the data. Our model is going to be built from reasonably "clean" events. Here, by clean we mean events that are not due to the nearly simultaneous firing of two or more neurons, simultaneity being defined on the time scale of one of our cuts. When the model is subsequently used to classify the data, events that are not "clean" are going to be decomposed into their (putative) constituents; that is, superpositions are going to be looked for and accounted for.

In order to eliminate the most obvious superpositions, we are going to use a rather brute force approach: outside the valley of our median event (that is, wherever the median is non-negative), we check that individual events do not deviate too much from the median, i.e. that they do not exhibit extra valleys. We first define a function doing this job:

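def good_evts_fct(samp, thr=3):
    """Return a boolean vector, True for the events (rows of samp) that stay
    within thr MADs of the median event wherever this median is non-negative."""
    samp_med = apply(np.median,0,samp)
    samp_mad = apply(swp.mad,0,samp)
    below = samp_med < 0            # points inside the valley(s) of the median
    samp_r = samp.copy()
    samp_r[:,below] = 0             # ignore these points for every event...
    samp_med[below] = 0             # ...and for the median
    res = np.all(np.abs(samp_r-samp_med) < thr*samp_mad, axis=1)
    return res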
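The selection rule can be checked on simulated data. The sketch below is self-contained: mad is a stand-in for swp.mad (assumed here to be the median absolute deviation scaled by 1.4826), good_evts_vectorized reproduces the logic of good_evts_fct without the swp dependency, and the simulated waveform (a valley followed by a positive bump, with unit-SD noise) is an arbitrary choice. One event receives an extra valley where the median is positive; it should be the only one rejected.

```python
import numpy as np

def mad(x, axis=0):
    """Median absolute deviation, scaled to estimate the SD of Gaussian data
    (stand-in for swp.mad)."""
    med = np.median(x, axis=axis, keepdims=True)
    return 1.4826 * np.median(np.abs(x - med), axis=axis)

def good_evts_vectorized(samp, thr=3):
    """Same rule as good_evts_fct: ignore the points where the median event is
    negative, require every other point to stay within thr MADs of the median."""
    samp_med = np.median(samp, axis=0)
    samp_mad = mad(samp, axis=0)
    dev = np.abs(samp - samp_med)          # deviation of each event (row)
    dev[:, samp_med < 0] = 0               # valley region of the median: ignored
    return np.all(dev < thr * samp_mad, axis=1)

rng = np.random.default_rng(0)
t = np.arange(45)
# simulated waveform: valley at t = 14, positive bump around t = 26
template = (-6 * np.exp(-0.5 * ((t - 14) / 2.0) ** 2)
            + 3 * np.exp(-0.5 * ((t - 26) / 3.0) ** 2))
samp = template + rng.normal(0, 1, size=(200, 45))   # unit-SD noise, as after normalization
samp[7, 24:29] -= 12                                 # event 7: an extra valley on the bump
good = good_evts_vectorized(samp, thr=8)
print(good.sum(), good[7])
```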

We then apply our new function to our sample using a threshold of 8 (set by trial and error):

goodEvts = good_evts_fct(evts,8)

Out of 1950 events we get 1881 "good" ones. As usual, the first 200 good ones can be visualized with (Fig. 11):

swp.plot_events(evts[goodEvts,:][:200,:])

first-200-clean-of-evts.png

Figure 11: First 200 "good" events of evts. Cuts from the four recording sites appear one after the other. The background (white / grey) changes with the site. In red, robust estimate of the "central" event obtained by computing the pointwise median. In blue, robust estimate of the scale (SD) obtained by computing the pointwise MAD.

8. Dimension reduction

8.1. Principal Component Analysis (PCA)

Our events now live in a 180-dimensional space (our cuts are 45 sampling points long and we are working with 4 recording sites simultaneously). It turns out that it is hard for most humans to perceive structures in such spaces. It is also hard, not to say impossible with a realistic sample size, to estimate probability densities (which is what model-based clustering algorithms actually do) in such spaces, unless one is ready to make strong assumptions about these densities. It is therefore usually good practice to try to reduce the dimension of the sample space used to represent the data. We are going to do that with principal component analysis (PCA), applied to our "good" events.

from numpy.linalg import svd
varcovmat = np.cov(evts[goodEvts,:].T)
u, s, v = svd(varcovmat)

With this "back to the roots" approach, u should be an orthonormal matrix whose columns are the principal components (and v should be the transpose of u, since our matrix varcovmat is real, symmetric and positive semi-definite by construction). s is a vector containing the amount of sample variance explained by each principal component.
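The "back to the roots" route can be checked on simulated data: the principal components obtained from the SVD of the variance-covariance matrix coincide, up to sign, with the right singular vectors of the centered data matrix, and the values in s are the squared singular values of that matrix divided by n - 1 (np.cov's default normalization). The simulated sample below (500 points in 6 dimensions, most of the variance along 2 directions) is an arbitrary assumption chosen for illustration.

```python
import numpy as np
from numpy.linalg import svd

rng = np.random.default_rng(1)
n, p = 500, 6
# simulated sample: two latent directions plus a little isotropic noise
latent = rng.normal(size=(n, 2)) * np.array([5.0, 2.0])
X = latent @ rng.normal(size=(2, p)) + rng.normal(scale=0.3, size=(n, p))

# route used in the text: SVD of the variance-covariance matrix
u, s, vh = svd(np.cov(X.T))

# alternative route: SVD of the centered data matrix
Xc = X - X.mean(axis=0)
_, sv, vh_data = svd(Xc, full_matrices=False)

explained = s / s.sum()     # fraction of the total variance per PC
scores = X @ u[:, :2]       # projections on the first 2 PCs (no centering, as in the text)
print(np.round(explained, 3))
```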

8.2. Exploring PCA results

PCA is a rather abstract procedure to most of its users, at least when they start using it. One way to grasp what it does is to plot the mean event plus or minus, say, five times each principal component (Fig. 12):

evt_idx = range(180)
evts_good_mean = np.mean(evts[goodEvts,:],0)
for i in range(4):
    plt.subplot(2,2,i+1)
    plt.plot(evt_idx,evts_good_mean, 'black',evt_idx,
             evts_good_mean + 5 * u[:,i],
             'red',evt_idx,evts_good_mean - 5 * u[:,i], 'blue')
    plt.axis('off')
    plt.title('PC' + str(i) + ': ' + str(round(s[i]/sum(s)*100)) +'%')

explore-evts-PC0to3.png

Figure 12: PCA of evts (for "good" events) exploration (PC 0 to 3). Each of the 4 graphs shows the mean waveform (black), the mean waveform + 5 x PC (red) and the mean - 5 x PC (blue) for each of the first 4 PCs. The fraction of the total variance "explained" by the component appears in the title of each graph.

We can see on Fig. 12 that the first 3 PCs (PC 0 to PC 2) correspond to pure amplitude variations. An event with a large projection (score) on PC 0 is smaller than the average event on recording sites 1, 2 and 3, but not on 4. An event with a large projection on PC 1 is larger than average on site 1, smaller than average on sites 2 and 3 and identical to the average on site 4. An event with a large projection on PC 2 is larger than the average on site 4 only. PC 3 is the first principal component corresponding to a change in shape as opposed to amplitude: a large projection on PC 3 means that the event has a shallower first valley and a deeper second valley than the average event on all recording sites.

We now look at the next 4 principal components:

for i in range(4,8):
    plt.subplot(2,2,i-3)
    plt.plot(evt_idx,evts_good_mean, 'black',
             evt_idx,evts_good_mean + 5 * u[:,i], 'red',
             evt_idx,evts_good_mean - 5 * u[:,i], 'blue')
    plt.axis('off')
    plt.title('PC' + str(i) + ': ' + str(round(s[i]/sum(s)*100)) +'%')

explore-evts-PC4to7.png

Figure 13: PCA of evts (for "good" events) exploration (PC 4 to 7). Each of the 4 graphs shows the mean waveform (black), the mean waveform + 5 x PC (red) and the mean - 5 x PC (blue). The fraction of the total variance "explained" by the component appears in the title of each graph.

An event with a large projection on PC 4 (Fig. 13) tends to be "slower" than the average event. An event with a large projection on PC 5 exhibits slower kinetics of its second valley than the average event. PC 4 and 5 correspond to effects shared among recording sites. PC 6 also corresponds to a "change of shape" effect, on all sites except the first. Events with a large projection on PC 7 rise slightly faster and decay slightly slower than the average event on all recording sites. Notice also that PC 7 has a "noisier" aspect than the others, suggesting that we are reaching the limit of the "events extra variability" compared to the variability present in the background noise.

8.3. Static representation of the projected data

We can build a scatter plot matrix showing the projections of our "good" events sample onto the planes defined by pairs of the first few PCs. Starting with the first 4, we get (Fig. 14):

evts_good_P0_to_P3 = np.dot(evts[goodEvts,:],u[:,0:4])
swp.splom(evts_good_P0_to_P3.T,
          ['PC 0','PC 1','PC 2','PC 3'],
          marker='.',linestyle='None',
          alpha=0.2,ms=1) 

Fig4.png

Figure 14: Scatter plot matrix of the projections of the good events in evts onto the planes defined by the first 4 PCs.

Clear structures (separated clusters) can be seen on all the projections. We check for structures on the next four principal components (Fig. 15):

evts_good_P4_to_P7 = np.dot(evts[goodEvts,:],u[:,4:8])
swp.splom(evts_good_P4_to_P7.transpose(),
          ['PC 4','PC 5','PC 6','PC 7'],
          marker='.',linestyle='None',
          alpha=0.2,ms=1) 

Fig5.png

Figure 15: Scatter plot matrix of the projections of the good events in evts onto the planes defined by PCs 4 to 7.

There is no clear structure left (except perhaps along PC 6), strongly suggesting that we won't gain anything, as far as clustering quality is concerned, by keeping more than the first 4 PCs.

8.4. Dynamic visualization of the data with GGobi

The best way to discern structures in "high dimensional" data is to visualize them dynamically. To this end, the tool of choice is GGobi, an open source program available on Linux, Windows and MacOS. We start by exporting the projections of our "good" events onto the first 8 PCs to disk, in csv format:

import csv
# one row per "good" event, one column per PC (the first 8)
with open('evts.csv','w',newline='') as f:
    w = csv.writer(f)
    w.writerows(np.dot(evts[goodEvts,:],u[:,:8]))

The following terse procedure should allow the reader to get going with GGobi:

  • Launch GGobi
  • In menu: File -> Open, select evts.csv.
  • Since the glyphs are rather large, start by changing them for smaller ones:
    • Go to menu: Interaction -> Brush.
    • On the Brush panel which appeared check the Persistent box.
    • Click on Choose color & glyph....
    • On the chooser which pops out, click on the small dot on the upper left of the left panel.
    • Go back to the window with the data points.
    • Right click on the lower right corner of the rectangle which appeared on the figure after you selected Brush.
    • Drag the rectangle corner in order to cover the whole set of points.
    • Go back to the Interaction menu and select the first row to go back where you were at the start.
  • Select menu: View -> Rotation.
  • Adjust the speed of the rotation in order to see things properly.

We easily discern 10 rather well separated clusters, which suggests that an automatic clustering with 10 clusters on the first 3 principal components should do the job.

9. Clustering with k-means

9.1. Using 10 centers

Since our dynamic visualization shows 10 well separated clusters in 3 dimensions, a simple k-means should do the job. We use here the kmeans2 function from scipy, with the supposedly 'optimal' initialization (minit set to '++', i.e. k-means++) and a fixed seed for reproducibility:

from scipy.cluster.vq import kmeans2
centroid_10, c10 = kmeans2(np.dot(evts[goodEvts,:],u[:,0:3]),
                           k=10, minit='++', seed=20110928)

The sklearn module also contains a k-means implementation, the KMeans class. It can be considered slightly more robust, since it can run the optimization multiple times from different starting points and keep the best result. We used the scipy kmeans2 function in this document, but feel free to use KMeans if you have sklearn installed. Here you would call:

from sklearn.cluster import KMeans
km10 = KMeans(n_clusters=10, init='k-means++', n_init=100, max_iter=100)
# fit_predict fits the model and returns the labels in a single call
c10 = km10.fit_predict(np.dot(evts[goodEvts,:],u[:,0:3]))

In order to facilitate comparisons between models with different numbers of clusters, or between different models (see below), clusters are sorted by 'size'. The size of a cluster is defined here as the L1 norm of its median event, i.e. the sum of the absolute values of the median's elements:

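# (label, median event) for each non-empty cluster
cluster_median = list([(i,
                        np.apply_along_axis(np.median,0,
                                            evts[goodEvts,:][c10 == i,:]))
                       for i in range(10)
                       if sum(c10 == i) > 0])
# size of each cluster: L1 norm of its median event
cluster_size = list([np.sum(np.abs(x[1])) for x in cluster_median])
# positions in cluster_median, from the largest to the smallest cluster
new_order = list(reversed(np.argsort(cluster_size)))
# new_order_reverse[k]: rank (new label) of the cluster at position k
new_order_reverse = sorted(range(len(new_order)), key=new_order.__getitem__)
# new labels (positions and labels coincide only if no cluster is empty)
c10b = [new_order_reverse[i] for i in c10]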

9.2. Using 9 centers

To check the robustness of this result and gain some insight, we repeat the procedure with 9 clusters:

centroid_9, c9 = kmeans2(np.dot(evts[goodEvts,:],u[:,0:3]),
                         k=9, minit='++', seed=20110928)
cluster_median9 = list([(i,
                         np.apply_along_axis(np.median,0,
                                             evts[goodEvts,:][c9 == i,:]))
                        for i in range(9)
                        if sum(c9 == i) > 0])
cluster_size9 = list([np.sum(np.abs(x[1])) for x in cluster_median9])
new_order9 = list(reversed(np.argsort(cluster_size9)))
new_order_reverse9 = sorted(range(len(new_order9)), key=new_order9.__getitem__)
c9b = [new_order_reverse9[i] for i in c9]

9.3. Using 11 centers

We also do it with one extra cluster:

centroid_11, c11 = kmeans2(np.dot(evts[goodEvts,:],u[:,0:3]),
                           k=11, minit='++', seed=20110928)
cluster_median11 = list([(i,
                         np.apply_along_axis(np.median,0,
                                             evts[goodEvts,:][c11 == i,:]))
                        for i in range(11)
                        if sum(c11 == i) > 0])
cluster_size11 = list([np.sum(np.abs(x[1])) for x in cluster_median11])
new_order11 = list(reversed(np.argsort(cluster_size11)))
new_order_reverse11 = sorted(range(len(new_order11)), key=new_order11.__getitem__)
c11b = [new_order_reverse11[i] for i in c11]
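The relabeling code is repeated almost verbatim for 9, 10 and 11 centers. A sketch of a helper that factors it out is given below; relabel_by_size is a hypothetical name, not part of swp. Unlike the code above, it also copes with empty clusters, because it maps the labels actually present rather than positions in a list.

```python
import numpy as np

def relabel_by_size(samples, labels):
    """Hypothetical helper: relabel clusters by decreasing 'size', the size of a
    cluster being the L1 norm of its median event (rows of `samples`)."""
    labels = np.asarray(labels)
    present = np.unique(labels)                     # non-empty clusters only
    sizes = np.array([np.sum(np.abs(np.median(samples[labels == k, :], axis=0)))
                      for k in present])
    order = present[np.argsort(sizes)[::-1]]        # old labels, largest first
    new_label = {old: new for new, old in enumerate(order)}
    return np.array([new_label[k] for k in labels])

samples = np.array([[1., 1.], [1., 1.], [10., 10.], [10., 10.], [5., 5.], [5., 5.]])
print(relabel_by_size(samples, [0, 0, 1, 1, 2, 2]))   # [2 2 0 0 1 1]
print(relabel_by_size(samples, [0, 0, 3, 3, 7, 7]))   # [2 2 0 0 1 1]
```

With it, c10b = list(relabel_by_size(evts[goodEvts,:], c10)) should give the same labels as the code above when no cluster is empty; converting to a list keeps the count method used in the comparison below available.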

9.4. Comparison

It is usually interesting to compare the number of events attributed to each cluster in models with different numbers of clusters, since the behavior described next tends to be robust. Of course, this makes sense only if the clusters have first been ordered. We do the counts with:

print("Nb of clusters:     9,    10,    11")
print("-----------------------------------")
for i in range(11):
    msg = "Cluster {0:6}: {1:5}, {2:5}, {3:5}"
    print(msg.format(i, c9b.count(i), c10b.count(i), c11b.count(i)))
print("-----------------------------------")

Nb of clusters:     9,    10,    11
-----------------------------------
Cluster      0:   139,   139,   139
Cluster      1:   117,   117,    74
Cluster      2:    83,    81,    43
Cluster      3:    71,    71,    81
Cluster      4:   199,   199,    71
Cluster      5:   164,    46,   199
Cluster      6:   244,   122,    46
Cluster      7:   416,   242,   122
Cluster      8:   448,   416,   242
Cluster      9:     0,   448,   415
Cluster     10:     0,     0,   449
-----------------------------------

What we see here is that, when the number of clusters increases, the k-means tends to split particular clusters: for instance, cluster 1 of the models with 9 and 10 clusters (117 events) gets split into clusters 1 and 2 (74 + 43 events) of the model with 11 clusters. So, for the largest clusters, the correspondence from one model to the next is typically easy to establish.
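Cross-tabulating the relabeled partitions of two models makes such splits explicit: entry (i, j) counts the events put in cluster i by one model and in cluster j by the other. crosstab below is a hypothetical helper built on numpy's np.add.at; applied as crosstab(c10b, c11b, 10, 11), it would show directly whether the 117 events of cluster 1 in the 10-cluster model are indeed the ones spread over clusters 1 and 2 of the 11-cluster model, something the counts alone only suggest.

```python
import numpy as np

def crosstab(labels_a, labels_b, n_a, n_b):
    """Hypothetical helper: table[i, j] = number of events labeled i in
    partition a and j in partition b."""
    table = np.zeros((n_a, n_b), dtype=int)
    # unbuffered in-place addition: repeated (i, j) pairs are all counted
    np.add.at(table, (np.asarray(labels_a), np.asarray(labels_b)), 1)
    return table

# toy example: cluster 1 of partition a is split into clusters 1 and 2 of b
a = [0, 0, 1, 1, 1, 1]
b = [0, 0, 1, 1, 2, 2]
print(crosstab(a, b, 2, 3))
```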

9.5. Cluster specific plots

Looking at the first 5 clusters (of the model with 10 clusters) we get Fig. 16 with:

for i in range(5):
    plt.subplot(5,1,i+1)   # same as plt.subplot(511), ..., plt.subplot(515)
    swp.plot_events(evts[goodEvts,:][np.array(c10b) == i,:])
    plt.ylim([-22,16])

events-clusters0to4.png

Figure 16: First 5 clusters. Cluster 0 at the top, cluster 4 at the bottom. Red, cluster specific central / median event. Blue, cluster specific MAD.

We see that some superposed events remain on all the plots; this is why it is very important to estimate the 'reference' waveform of each neuron with the median. We also see that the MAD increases systematically where the derivative of the reference waveform (in red) is large in absolute value; this is a sampling jitter effect. At this stage, it can happen that some clusters contain very few events; it is then a good idea to restart the k-means procedure with a different seed. Looking at the last 5 clusters we get Fig. 17 with:

for i in range(5,10):
    plt.subplot(5,1,i-4)   # same as plt.subplot(511), ..., plt.subplot(515)
    swp.plot_events(evts[goodEvts,:][np.array(c10b) == i,:])
    plt.ylim([-12,10])

events-clusters5to9.png

Figure 17: Last 5 clusters. Cluster 5 at the top, cluster 9 at the bottom. Red, cluster specific central / median event. Blue, cluster specific MAD. Notice the change in ordinate scale compared to the previous figure.
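The link between the MAD and the slope of the reference waveform can be reproduced by simulation. In the sketch below the waveform shape, the uniform jitter of at most half a sampling period and the noise level are all arbitrary assumptions, and mad is again a stand-in for swp.mad. Where the waveform is flat, the MAD stays at the noise level; where it is steep, a sub-sample shift changes the sampled values a lot and the MAD goes up.

```python
import numpy as np

def mad(x, axis=0):
    """Median absolute deviation scaled to estimate a Gaussian SD (stand-in for swp.mad)."""
    med = np.median(x, axis=axis, keepdims=True)
    return 1.4826 * np.median(np.abs(x - med), axis=axis)

def waveform(t):
    # arbitrary spike-like shape: a valley at t = 14 and a small bump at t = 20
    return (-6 * np.exp(-0.5 * ((t - 14) / 1.5) ** 2)
            + 2 * np.exp(-0.5 * ((t - 20) / 3.0) ** 2))

rng = np.random.default_rng(2)
t = np.arange(45.0)
noise_sd = 0.3
jitter = rng.uniform(-0.5, 0.5, size=500)      # sub-sample misalignment (in sampling points)
sim = (np.array([waveform(t + d) for d in jitter])
       + rng.normal(0, noise_sd, size=(500, 45)))

sim_mad = mad(sim, axis=0)
slope = np.abs(np.gradient(waveform(t)))        # |f'| estimated on the sampling grid
print(round(sim_mad[np.argmax(slope)], 2), round(sim_mad[np.argmin(slope)], 2))
```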

9.6. Results inspection with GGobi

We start by checking our clustering quality with GGobi. To this end we export the data and the labels of each event:

# projections on the first 8 PCs, plus a last column with the labels (c10b)
with open('evtsSorted.csv','w',newline='') as f:
    w = csv.writer(f)
    w.writerows(np.concatenate((np.dot(evts[goodEvts,:],u[:,:8]),
                                np.array([c10b]).T),
                               axis=1))

Again, a succinct description of how to do the dynamic visual check is:

  • Load the new data into GGobi like before.
  • In menu: Display -> New Scatterplot Display, select evtsSorted.csv.
  • Change the glyphs like before.
  • In menu: Tools -> Color Schemes, select a scheme with 10 colors, like Spectral, Spectral 10.
  • In menu: Tools -> Automatic Brushing, select the evtsSorted.csv tab and, within this tab, select variable c10b. Then click on Apply.
  • Select View -> Rotation like before and see your result.

9.7. More detailed analysis of the waveforms

Risk of missed detections

The first thing we want to look at is how far beyond the detection threshold the valleys of each cluster's 'reference' waveform go on the filtered version of the data (since detection is done on the filtered data). To that end, we define a function that returns the median waveform (the 'reference') of each cluster computed from the filtered (and MAD normalized) data.

from scipy.signal import fftconvolve

def filtered_reference(cluster_index,
                       labels_lst,
                       spike_pos,
                       data,
                       filter_length,
                       before,
                       after):
    """Computes the median of the events whose label is 'cluster_index'
    on a filtered and MAD normalized version of 'data'.

    Parameters
    ----------
    cluster_index: an integer between 0 and the number of clusters - 1
    labels_lst: a list of labels containing the cluster to which events
                have been attributed
    spike_pos: an array of spike indices (same length as labels_lst)
    data: the data
    filter_length: a positive integer, the length of the box filter
    before: positive integer, the number of points before the 'peak' in
            the window (see mk_events)
    after: positive integer, the number of points after the 'peak' in
           the window (see mk_events)

    Returns
    -------
    An array with the 'reference' waveform.
    """
    # box ('moving average') filter of length filter_length on each site
    data_filtered = apply(lambda x:
                          fftconvolve(x,np.ones(filter_length)/filter_length,'same'),
                          1,np.array(data))
    # normalize each site by the MAD of its filtered version
    dfiltered_mad = apply(swp.mad,1,data_filtered)
    data_filtered = (data_filtered.transpose() / \
                     dfiltered_mad).transpose()
    # spike positions of the events attributed to cluster_index
    sp = spike_pos[np.array(labels_lst) == cluster_index]
    evts = swp.mk_events(sp,data_filtered,before,after)
    res = np.apply_along_axis(np.median,0,evts)
    return res

Next, we plot on Fig. 18 the filtered reference of each cluster on the same scale, together with a red dashed line at the detection threshold level (-4) and a blue dotted line at the threshold minus 3 times the MAD (-7; deviations due to background noise should rarely exceed this level):

plt.subplot(5,2,1)
ref = filtered_reference(0,c10b,sp0[goodEvts],data,5,14,30)
ref_min = np.floor(min(ref))
ref_max = np.ceil(max(ref))
plt.plot(ref,color='black',lw=1)
plt.axhline(y=-4,color="red",linestyle="dashed")
plt.axhline(y=-7,color="blue",linestyle="dotted")
plt.ylim([ref_min,ref_max])
plt.axis('off')
plt.title('Cluster '+str(0))
for i in range(1,10):
    plt.subplot(5,2,i+1)
    ref = filtered_reference(i,c10b,sp0[goodEvts],data,5,14,30)
    plt.plot(ref,color='black',lw=1)
    plt.axhline(y=-4,color="red",linestyle="dashed")
    plt.axhline(y=-7,color="blue",linestyle="dotted")
    plt.ylim([ref_min,ref_max])
    plt.axis('off')
    plt.title('Cluster '+str(i))

filtered-reference.png

Figure 18: The 'reference' waveform (the median) of each cluster computed from the filtered and MAD normalized version of the data (the version of the data on which the detection was performed). Dashed red, the detection threshold (-4); dotted blue, detection threshold minus 3 times the MAD.

We see that the lowest valley of the last two clusters is just at threshold level, meaning that only about half of the events from these clusters are going to be detected. The lowest valley of cluster 7 is below threshold but above the blue line, meaning that a fraction of its events are likely to be missed. For the seven other clusters, the detection should be fine (that is, we should not miss events). So events from clusters 7, 8 and 9 should not be used for further analysis (spike trains, etc.), at least not as spikes from well identified and properly detected neurons.
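A crude model makes the 'about half' statement concrete. Assume that, on the filtered and normalized data, the noise added to the lowest valley of a reference waveform is Gaussian with unit SD (the MAD normalization), and ignore the fact that a valley spans several correlated samples. The probability that an event from a cluster is detected is then the probability that valley + noise falls below the threshold. detection_probability is a hypothetical helper written for this illustration only.

```python
from math import erf, sqrt

def detection_probability(valley, threshold=-4.0, noise_sd=1.0):
    """Crude sketch: P(valley + noise < threshold) for Gaussian noise."""
    z = (threshold - valley) / noise_sd
    return 0.5 * (1.0 + erf(z / sqrt(2.0)))    # standard normal CDF at z

for v in (-4.0, -5.0, -7.0):
    print(v, round(detection_probability(v), 4))
```

Under this model, a valley exactly at threshold gives a detection probability of 0.5, while a valley at the blue line (-7) gives about 0.999.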

Waveform variability

Our discussion of the sources of variability, when we looked at the results of the principal component analysis, led us to conclude that variability mainly arises from amplitude differences as opposed to shape differences. We can check this further by comparing the normalized (valley minimum at -1) reference waveform of each cluster on each recording site with the mean normalized waveform. This is what we will now do for the 7 'good' clusters. We start by reordering the waveforms in cluster_median according to their size:

cluster_median_o = [cluster_median[new_order[i]][1] for i in range(10)]

Next, we cut each 180-point-long waveform into four 45-point-long 'sub-waveforms', one per recording site, and normalize each of them so that its valley depth is -1:

cluster_median_n = [[cluster_median_o[i][45*j:45*(j+1)]/\
                     abs(min(cluster_median_o[i][45*j:45*(j+1)]))
                     for j in range(4)] for i in range(10)]

We now compute the mean normalized sub-waveform from the first seven clusters:

cluster_median_mean = np.zeros(45)
for i in range(7):
    for j in range(4):
        cluster_median_mean += cluster_median_n[i][j]
cluster_median_mean /= 28

We can now construct Fig. 19:

for i in range(7):
    plt.subplot(3,3,i+1)
    plt.plot(cluster_median_mean,color='red',lw=2)
    plt.plot(np.array(cluster_median_n[i]).transpose(),color='black')
    plt.ylim([-1.1,0.9])
    plt.axis('off')
    plt.title('Cluster '+str(i))
    

fig-shape-comparison.png

Figure 19: Normalized reference waveform on each recording site for each of the 7 'good' clusters / neurons (black) and mean of the 28 normalized waveforms (thick red).

We see that the sub-waveforms of a given cluster are essentially the same on the 4 recording sites (keep in mind that the normalization 'amplifies' the noise in some cases). They are also very close to the global mean sub-waveform. This confirms our analysis of the PCA results: we have basically a single waveform that varies only a little from neuron to neuron, and the sub-waveforms of a given neuron / cluster are scaled versions of this single waveform.
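The statement that the sub-waveforms of a neuron are scaled versions of a single waveform can be quantified with a least-squares fit of one scale factor per recording site, a_j = (w_j · m)/(m · m), where m is the common shape. scale_fit below is a hypothetical helper; calling it, e.g., as scale_fit(cluster_median_o[i].reshape(4, 45), cluster_median_mean) would give, for cluster i, the four factors and the fraction of each sub-waveform's energy that the scaled shape fails to capture (values close to 0 supporting the single-shape description).

```python
import numpy as np

def scale_fit(sub_waveforms, shape):
    """Hypothetical helper: least-squares fit w_j ~ a_j * shape for each row w_j.
    Returns the scale factors a_j and the fraction of each row's energy
    (sum of squares) left in the residual."""
    W = np.atleast_2d(np.asarray(sub_waveforms, dtype=float))  # one row per site
    m = np.asarray(shape, dtype=float)
    a = W @ m / (m @ m)                  # a_j = (w_j . m) / (m . m)
    resid = W - np.outer(a, m)
    frac = np.sum(resid ** 2, axis=1) / np.sum(W ** 2, axis=1)
    return a, frac

shape = -np.exp(-0.5 * ((np.arange(45) - 14) / 2.0) ** 2)   # toy common shape
W = np.vstack([2.0 * shape,
               0.5 * shape,
               0.5 * shape + 0.1 * np.sin(np.arange(45))])  # last row: not a pure scaling
a, frac = scale_fit(W, shape)
print(np.round(a, 3), np.round(frac, 4))
```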

10. Dealing with the sampling jitter and its consequences

Our superposition resolution approach, illustrated in the next section, consists in subtracting the 'best' waveform every time a putative event is detected. For this approach to work smoothly, we must deal with the misalignment problems due to the finite (time) sampling of our data. This section illustrates the problem with one of the events detected in our data set. Cluster 3 is the cluster exhibiting the largest sampling jitter effects, since its median event has the largest time derivative (in absolute value). This is seen when we superpose the 13th event from this cluster with the median event (remember that we start numbering at 0). We first get our estimate of the center, or template, of cluster 3:

c3_median = apply(np.median,0,evts[goodEvts,:][np.array(c10b)==3,:])

And we do the plot (Fig. 20):

plt.plot(c3_median[:],color='red')
plt.plot(evts[goodEvts,:][np.array(c10b)==3,:][13,:],color='black')

JitterIllustrationCluster3Event13.png

Figure 20: The median event of cluster 3 (red) together with event 13 of the same cluster (black).

A Taylor expansion shows that, if we write \(g(t)\) for the observed 13th event, \(\delta{}\) for the sampling jitter and \(f(t)\) for the actual waveform of the event, then:

\begin{equation} g(t) = f(t+\delta{}) + \epsilon{}(t) \approx f(t) + \delta{} \, f'(t) + \delta{}^2/2 \, f''(t) + \epsilon{}(t) \, ; \end{equation}

where \(\epsilon{}\) is a Gaussian process and where \(f'\) and \(f''\) stand for the first and second time derivatives of \(f\). Therefore, if we can get estimates of \(f'\) and \(f''\), we should be able to estimate \(\delta{}\) by linear regression (if we neglect the \(\delta{}^2\) term as well as the potentially non-null correlation in \(\epsilon{}\)) or by non-linear regression (if we keep the \(\delta{}^2\) term). We start by getting the derivative estimates:

dataD = apply(lambda x: fftconvolve(x,np.array([1,0,-1])/2.,'same'),
              1, data)
evtsED = swp.mk_events(sp0,dataD,14,30)
dataDD = apply(lambda x: fftconvolve(x,np.array([1,0,-1])/2.,'same'),
               1, dataD)
evtsEDD = swp.mk_events(sp0,dataDD,14,30)
c3D_median = apply(np.median,0,
                   evtsED[goodEvts,:][np.array(c10b)==3,:])
c3DD_median = apply(np.median,0,
                    evtsEDD[goodEvts,:][np.array(c10b)==3,:])
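Since convolution flips the kernel, convolving with [1, 0, -1]/2 in 'same' mode yields (x[n+1] - x[n-1])/2, the central difference: a derivative estimate in units of 'per sampling point', with the correct sign. This can be checked on a toy signal (the sine below is arbitrary); away from the two edges, the result also coincides with np.gradient.

```python
import numpy as np
from scipy.signal import fftconvolve

x = np.sin(np.linspace(0, 2 * np.pi, 50))                  # arbitrary toy signal
d_conv = fftconvolve(x, np.array([1, 0, -1]) / 2., 'same')  # as for dataD above
d_central = (x[2:] - x[:-2]) / 2                            # (x[n+1] - x[n-1]) / 2
print(np.allclose(d_conv[1:-1], d_central),
      np.allclose(d_conv[1:-1], np.gradient(x)[1:-1]))
```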

We then get something like Fig. 21:

plt.plot(evts[goodEvts,:][np.array(c10b)==3,:][13,:]-\
         c3_median,color='black',lw=2)
plt.plot(1.3*c3D_median,color='red',lw=2)

JitterIllustrationCluster3Event13b.png

Figure 21: The median event of cluster 3 subtracted from event 13 of the same cluster (black) and 1.3 times the first derivative of the median event (red).

If we neglect the \(\delta{}^2\) term we quickly arrive at:

\begin{equation} \hat{\delta{}} = \frac{\mathbf{f'} \cdot (\mathbf{g} -\mathbf{f})}{\| \mathbf{f'} \|^2} \, ; \end{equation}

where the 'vectorial' notation like \(\mathbf{a} \cdot \mathbf{b}\) stands here for: \[ \sum_{i=0}^{179} a_i b_i \, . \]
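The formula follows from setting to zero the derivative of \(\| \mathbf{g} - \mathbf{f} - \delta{} \, \mathbf{f'} \|^2\) with respect to \(\delta{}\), namely \(-2 \, \mathbf{f'} \cdot (\mathbf{g} - \mathbf{f} - \delta{} \, \mathbf{f'})\). It can be tried on a noiseless simulated event with a known shift. In the sketch below, the Gaussian-shaped waveform and the true shift of 0.3 sampling point are arbitrary assumptions, and the derivative is estimated by central differences (np.gradient), as in the text. Because of the finite-difference approximation, delta_hat comes out close to, but not exactly equal to, the true shift.

```python
import numpy as np

def f(t):
    # arbitrary smooth, spike-like waveform with a valley at t = 20
    return -5 * np.exp(-0.5 * ((t - 20) / 2.0) ** 2)

t = np.arange(45.0)
true_delta = 0.3
g = f(t + true_delta)                 # noiseless event, shifted by true_delta
fD = np.gradient(f(t))                # central-difference estimate of f'
delta_hat = np.dot(fD, g - f(t)) / np.dot(fD, fD)
print(round(delta_hat, 3))
```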

For the 13th event of the cluster we get:

delta_hat = np.dot(c3D_median,
                   evts[goodEvts,:][np.array(c10b)==3,:][13,:]-\
                   c3_median)/np.dot(c3D_median,c3D_median)
delta_hat
1.3267017731512747

We can use this estimated value of delta_hat as an initial guess for a procedure refining the estimate by also using the \(\delta{}^2\) term. The obvious quantity to minimize is the residual sum of squares, RSS, defined by: \[ \mathrm{RSS}(\delta{}) = \| \mathbf{g} - \mathbf{f} - \delta{} \, \mathbf{f'} - \delta{}^2/2 \, \mathbf{f''} \|^2 \; . \] We can define a function returning the RSS given a value of \(\delta{}\), an event evt, a cluster center center (the median event of the cluster) and its first two derivatives, centerD and centerDD:

def rss_fct(delta,evt,center,centerD,centerDD):
    return np.sum((evt - center - delta*centerD - delta**2/2*centerDD)**2)

To quickly graph the RSS as a function of \(\delta{}\) for the specific case we are dealing with now (13th element of cluster 3), we create a vectorized (universal function) version of the rss_fct function we just defined:

urss_fct = np.frompyfunc(lambda x:
                         rss_fct(x,
                                 evts[goodEvts,:]\
                                 [np.array(c10b)==3,:][13,:],
                                 c3_median,c3D_median,c3DD_median),1,1)

We then get Fig. 22 with:

plt.subplot(1,2,1)
dd = np.arange(-5,5,0.05)
plt.plot(dd,urss_fct(dd),color='black',lw=2)
plt.subplot(1,2,2)
dd_fine = np.linspace(delta_hat-0.5,delta_hat+0.5,501)
plt.plot(dd_fine,urss_fct(dd_fine),color='black',lw=2)
plt.axvline(x=delta_hat,color='red')

JitterIllustrationCluster3Event13c.png

Figure 22: The RSS as a function of \(\delta{}\) for event 13 of cluster 3. Left, \(\delta{} \in [-5,5]\); right, \(\delta{} \in [\hat{\delta{}}-0.5,\hat{\delta{}}+0.5]\) and the red vertical line shows \(\hat{\delta{}}\).

Fig. 22 shows that our initial guess for \(\hat{\delta{}}\) is not bad but still approximately 0.2 units away from the actual minimum. The classical way to refine our \(\delta{}\) estimate, in 'nice situations' where the function we are trying to minimize is locally convex, is the Newton-Raphson algorithm: it approximates the 'target function' (here our RSS function) locally by a parabola having the same first and second derivatives, then jumps to the minimum of this approximating parabola. If we expand our previous expression of \(\mathrm{RSS}(\delta{})\) we get: \[ \mathrm{RSS}(\delta{}) = \| \mathbf{h} \|^2 - 2\, \delta{} \, \mathbf{h} \cdot \mathbf{f'} + \delta{}^2 \, \left( \|\mathbf{f'}\|^2 - \mathbf{h} \cdot \mathbf{f''}\right) + \delta{}^3 \, \mathbf{f'} \cdot \mathbf{f''} + \frac{\delta{}^4}{4} \|\mathbf{f''}\|^2 \, ; \] where \(\mathbf{h}\) stands for \(\mathbf{g} - \mathbf{f}\).

By differentiation with respect to \(\delta{}\) we get: \[ \mathrm{RSS}'(\delta{}) = - 2\, \mathbf{h} \cdot \mathbf{f'} + 2 \, \delta{} \, \left( \|\mathbf{f'}\|^2 - \mathbf{h} \cdot \mathbf{f''}\right) + 3 \, \delta{}^2 \, \mathbf{f'} \cdot \mathbf{f''} + \delta{}^3 \|\mathbf{f''}\|^2 \, . \] And a second differentiation leads to: \[ \mathrm{RSS}''(\delta{}) = 2 \, \left( \|\mathbf{f'}\|^2 - \mathbf{h} \cdot \mathbf{f''}\right) + 6 \, \delta{} \, \mathbf{f'} \cdot \mathbf{f''} + 3 \, \delta{}^2 \|\mathbf{f''}\|^2 \, . \]

The equation of the approximating parabola at \(\delta{}^{(k)}\) is then: \[ \mathrm{RSS}(\delta{}^{(k)} + \eta) \approx \mathrm{RSS}(\delta{}^{(k)}) + \eta \, \mathrm{RSS}'(\delta{}^{(k)}) + \frac{\eta^2}{2} \, \mathrm{RSS}''(\delta{}^{(k)})\; , \] and its minimum, if \(\mathrm{RSS}''(\delta{}^{(k)}) > 0\), is located at: \[ \delta{}^{(k+1)} = \delta{}^{(k)} - \frac{\mathrm{RSS}'(\delta{}^{(k)})}{\mathrm{RSS}''(\delta{}^{(k)})} \; . \]

Defining functions returning the required derivatives:

def rssD_fct(delta,evt,center,centerD,centerDD):
    h = evt - center
    return -2*np.dot(h,centerD) + \
      2*delta*(np.dot(centerD,centerD) - np.dot(h,centerDD)) + \
      3*delta**2*np.dot(centerD,centerDD) + \
      delta**3*np.dot(centerDD,centerDD)

def rssDD_fct(delta,evt,center,centerD,centerDD):
    h = evt - center
    return 2*(np.dot(centerD,centerD) - np.dot(h,centerDD)) + \
      6*delta*np.dot(centerD,centerDD) + \
      3*delta**2*np.dot(centerDD,centerDD)

we can get a graphical representation (Fig. 23) of a single step of the Newton-Raphson algorithm:

evt_13 = evts[goodEvts,:][np.array(c10b)==3,:][13,:] # event 13 of cluster 3
rss_at_delta0 = rss_fct(delta_hat,evt_13,
                        c3_median,c3D_median,c3DD_median)
rssD_at_delta0 = rssD_fct(delta_hat,evt_13,
                          c3_median,c3D_median,c3DD_median)
rssDD_at_delta0 = rssDD_fct(delta_hat,evt_13,
                            c3_median,c3D_median,c3DD_median)
# A single Newton-Raphson step starting from delta_hat
delta_1 = delta_hat - rssD_at_delta0/rssDD_at_delta0
plt.plot(dd_fine,urss_fct(dd_fine),color='black',lw=2)
plt.axvline(x=delta_hat,color='red')
plt.plot(dd_fine,
         rss_at_delta0 + (dd_fine-delta_hat)*rssD_at_delta0 + \
         (dd_fine-delta_hat)**2/2*rssDD_at_delta0,color='blue',lw=2)
plt.axvline(x=delta_1,color='grey')

JitterIllustrationCluster3Event13d.png

Figure 23: The RSS as a function of \(\delta{}\) for event 13 of cluster 3 (black); the red vertical line shows \(\hat{\delta{}}\). In blue, the approximating parabola at \(\hat{\delta{}}\). The grey vertical line shows the minimum of the approximating parabola.
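To check the derivative formulas above and to see what further Newton-Raphson iterations would do, here is a self-contained sketch on synthetic data. The Gaussian-shaped template with analytic derivatives, the true jitter of 0.6 and the iteration cap are assumptions chosen for illustration, not the locust data; rss_d and rss_dd implement the same formulas as rssD_fct and rssDD_fct.

```python
import numpy as np

# Synthetic template (an assumption): a Gaussian-shaped valley sampled on
# 41 points, with its analytic first and second time derivatives.
t = np.arange(-20, 21, dtype=float)
env = np.exp(-t**2 / 8)
center = -env
centerD = t / 4 * env
centerDD = (1 / 4 - t**2 / 16) * env

# Synthetic event: the template shifted by a known jitter, g(t) = f(t + delta)
delta_true = 0.6
evt = -np.exp(-(t + delta_true)**2 / 8)
h = evt - center

def rss(delta):
    """RSS of the second order model ||h - delta f' - delta^2/2 f''||^2."""
    r = h - delta * centerD - delta**2 / 2 * centerDD
    return np.dot(r, r)

def rss_d(delta):
    """RSS'(delta), same formula as rssD_fct."""
    return (-2 * np.dot(h, centerD)
            + 2 * delta * (np.dot(centerD, centerD) - np.dot(h, centerDD))
            + 3 * delta**2 * np.dot(centerD, centerDD)
            + delta**3 * np.dot(centerDD, centerDD))

def rss_dd(delta):
    """RSS''(delta), same formula as rssDD_fct."""
    return (2 * (np.dot(centerD, centerD) - np.dot(h, centerDD))
            + 6 * delta * np.dot(centerD, centerDD)
            + 3 * delta**2 * np.dot(centerDD, centerDD))

# First order (linear regression) estimate, then Newton-Raphson iterations
delta0 = np.dot(h, centerD) / np.dot(centerD, centerD)
delta = delta0
for _ in range(20):
    step = rss_d(delta) / rss_dd(delta)
    delta -= step
    if abs(step) < 1e-12:
        break

print(delta0, delta, rss(delta0), rss(delta))
```

Comparing delta0 with delta shows how far the Newton-Raphson refinement moves the first order estimate; the document itself stops after a single step.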

Subtracting the second order (in \(\delta{}\)) approximation of \(f(t+\delta{})\), with \(\delta{}\) set to the result of this Newton-Raphson step, from event 13 of cluster 3, we get Fig. 24:

plt.plot(evts[goodEvts,:][np.array(c10b)==3,:][13,:]-\
         c3_median-delta_1*c3D_median-delta_1**2/2*c3DD_median,
         color='red',lw=2)
plt.plot(evts[goodEvts,:][np.array(c10b)==3,:][13,:],
         color='black',lw=2)
plt.plot(c3_median+delta_1*c3D_median+delta_1**2/2*c3DD_median,
         color='blue',lw=1)

JitterIllustrationCluster3Event13e.png

Figure 24: Event 13 of cluster 3 (black), second order approximation of \(f(t+\delta{})\) (blue) and residual (red) for \(\delta{} = 1.2352313930704102\), obtained by a linear regression (order 1) followed by a single Newton-Raphson step.

11. Classification by spike 'peeling': a 'Brute force' superposition resolution

11.1. Where are we?

We have just spent a lot of time working on only 29 seconds of data, while the experiment from which these data are taken consists of about 1.5 hours of recordings. Before you despair, you should realize that these recordings are stable: there are small drifts, but they are typically slow. This means that once we have our catalog of motifs, things can go much faster. The way I do it (this is fully detailed on a GitHub repository page dedicated to the analysis of this experiment; it is done with R, but you should have no problem following) is as follows: successive acquisitions (which are typically 29 seconds long) are sorted one after the other. Each acquisition is sorted with the motifs of the previous one, in the way I will detail in the next section. You will see that this sorting yields a collection of action potential times for each 'neuron' or motif in the catalog, so for each acquisition sequence we obtain a number of events per motif as well as a number of unclassified events. At the end of the classification of 29 seconds of data, the motifs are updated, typically by taking the average of the motif from the previous 29 seconds and of the one computed from the events that have just been assigned to it. In this way, slow drifts are simply taken into account. If a large drift occurs, with the consequence that the events of a neuron / motif are no longer attributed to it, we simultaneously see the number of unclassified events increase; we therefore have a criterion indicating the need to recompute our motifs catalog, as we have done since the beginning of this document. I should point out that this approach is not the most commonly used: for some reason (absurd in my opinion), the bulk of spike sorting practitioners analyze all their data 'en bloc', which obviously makes everything more complicated…
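The motif update and the drift criterion described above could look roughly as follows. This is a minimal sketch, not the R code of the repository mentioned above: the names update_motif and needs_new_catalog are hypothetical, computing the new motif as a pointwise median (as the centers are computed in this document) is an assumption, and so is the factor-of-two alarm threshold.

```python
import numpy as np

def update_motif(previous_motif, assigned_events):
    """Average the previous motif with the one computed from the events
    just assigned to it (here their pointwise median, an assumption)."""
    new_motif = np.median(assigned_events, axis=0)
    return (previous_motif + new_motif) / 2

def needs_new_catalog(n_unclassified, baseline, factor=2.0):
    """Hypothetical drift alarm: True when the number of unclassified
    events exceeds `factor` times its usual (baseline) value."""
    return n_unclassified > factor * baseline

# Toy example: a motif whose amplitude slowly drifts
old = np.array([0.0, -10.0, 0.0])
new_events = np.array([[0.0, -9.0, 0.0],
                       [0.1, -9.2, 0.0],
                       [-0.1, -8.8, 0.1]])
print(update_motif(old, new_events))
print(needs_new_catalog(30, baseline=10))
```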

11.2. Superposition resolution through 'peeling'

We will resolve the (most obvious) overlaps / superpositions by recursive 'peeling':

  1. Events are detected and cut from the raw data or from an 'already peeled' version of them.
  2. The closest motif (in the sense of the Euclidean distance) among those in the catalog is found; jitter is systematically compensated when distances are computed.
  3. If the residual sum of squares (RSS), i.e. \(\|\text{event} - \text{nearest motif}\|^2\), is smaller than \(\|\text{event}\|^2\), the nearest motif is subtracted from the data, with jitter compensation.
  4. We return to step 1, or we stop if there are no more events attributable to one of the motifs in the catalog.
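The four steps above can be sketched on a single recording site. This is a deliberately simplified toy, not the document's implementation: there is no jitter compensation, the crude detect function (local minima below a threshold) stands in for the detection used in the document, and the names cut, detect and peel are hypothetical.

```python
import numpy as np

def cut(x, pos, before, after):
    """Window x[pos-before : pos+after+1], zero padded at the edges."""
    out = np.zeros(before + after + 1)
    idx = np.arange(pos - before, pos + after + 1)
    ok = (idx >= 0) & (idx < len(x))
    out[ok] = x[idx[ok]]
    return out

def detect(x, threshold):
    """Local minima below -threshold (a crude stand-in for detection)."""
    inner = x[1:-1]
    is_min = (inner < x[:-2]) & (inner <= x[2:]) & (inner < -threshold)
    return np.flatnonzero(is_min) + 1

def peel(data, motifs, before, after, threshold, max_rounds=10):
    """Recursive peeling on a single site, without jitter compensation."""
    residual = np.array(data, dtype=float)       # working copy
    found = []                                   # (position, motif index)
    for _ in range(max_rounds):
        accepted = 0
        for pos in detect(residual, threshold):  # step 1
            evt = cut(residual, pos, before, after)
            rss = [np.sum((evt - m)**2) for m in motifs]
            k = int(np.argmin(rss))              # step 2: closest motif
            if rss[k] < np.sum(evt**2):          # step 3: accept if RSS < ||evt||^2
                start = pos - before
                lo = max(start, 0)
                hi = min(start + len(motifs[k]), len(residual))
                residual[lo:hi] -= motifs[k][lo - start:hi - start]
                found.append((int(pos), k))
                accepted += 1
        if accepted == 0:                        # step 4: nothing left to peel
            break
    return found, residual

# Toy data: two motifs, isolated and superposed (150 and 156 overlap)
tt = np.arange(-10, 11)
m0 = -10 * np.exp(-tt**2 / 8)
m1 = -5 * np.exp(-tt**2 / 8)
data = np.zeros(200)
for pos, m in [(50, m0), (120, m1), (150, m0), (156, m1)]:
    data[pos - 10:pos + 11] += m
found, residual = peel(data, [m0, m1], before=10, after=10, threshold=2)
print(found)
```

Within a round, later events are cut from the already partially peeled residual, so the large spike at 150 is removed before the smaller one at 156 is examined.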

11.3. Some remarks

The Euclidean distance is used to measure the distance between a motif and an event; this is what we have always done until now. We could be more rigorous and take into account the auto-correlation of the noise. If e (column vector) is an event and m a motif, we currently compute \((e-m)^T(e-m)\), the square of the Euclidean distance. Instead, we could estimate \(\Sigma\), the noise covariance matrix (from the data chunks between detected events), and replace the square of the Euclidean distance by the square of the Mahalanobis distance: \((e-m)^T \Sigma^{-1}(e-m)\). In practice, this is computationally expensive without changing the results, which is why I use the Euclidean distance.
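Both distances can be compared in a short sketch. The AR(1) noise below is an assumption standing in for the noise cuts one would take between detected events, and the toy motif is hypothetical; \(\Sigma\) is estimated with np.cov, and np.linalg.solve is used instead of explicitly inverting \(\Sigma\).

```python
import numpy as np

def sq_euclidean(e, m):
    """Square of the Euclidean distance, (e-m)^T (e-m)."""
    d = e - m
    return float(d @ d)

def sq_mahalanobis(e, m, sigma):
    """Square of the Mahalanobis distance, (e-m)^T Sigma^{-1} (e-m).
    Solving Sigma x = (e-m) avoids forming the inverse explicitly."""
    d = e - m
    return float(d @ np.linalg.solve(sigma, d))

# Hypothetical noise cuts: AR(1) noise, i.e. autocorrelated along each cut
rng = np.random.default_rng(0)
n_cuts, cut_length = 2000, 45
white = rng.standard_normal((n_cuts, cut_length))
noise = np.empty_like(white)
noise[:, 0] = white[:, 0]
for j in range(1, cut_length):
    noise[:, j] = 0.7 * noise[:, j - 1] + white[:, j]
sigma = np.cov(noise, rowvar=False)    # noise covariance matrix estimate

t = np.arange(cut_length) - 14
m = -10 * np.exp(-t**2 / 8)            # a toy motif
e = m + noise[0]                       # a toy event
print(sq_euclidean(e, m), sq_mahalanobis(e, m, sigma))
```

The extra cost comes from the linear solve, which scales with the cube of the event dimension (180 here for 4 sites of 45 points); factoring \(\Sigma\) once (e.g. a Cholesky decomposition) would amortize it across comparisons.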

I accept a proposed subtraction as long as it decreases the L2 norm computed on the cut of the event under consideration. This can result in overfitting, especially when small amplitude motifs (those of poorly identified neurons) are used. In practice this is not crucial, because the analysis of the action potential sequences (the data analysis that follows spike sorting) normally focuses only on the sequences of the 'good' neurons. Nevertheless, the procedure could be refined by accepting a subtraction only if it decreases the L2 norm (as is done now) and leaves a residual whose norm is not unlikely given the noise level; again, the distribution of the noise L2 norm can be estimated from the trace pieces between detected events.
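The refinement suggested above could be sketched as follows. The function name accept_subtraction, the use of the 0.99 quantile as the 'unlikely' cut-off and the white-noise stand-in for the noise cuts are all assumptions for illustration.

```python
import numpy as np

def accept_subtraction(evt, motif, noise_norm2, q=0.99):
    """Hypothetical refined rule: accept subtracting `motif` from `evt` only
    if (a) it decreases the squared L2 norm, as in the current procedure,
    and (b) the residual squared norm does not exceed the q-quantile of the
    squared norms of noise cuts (noise_norm2)."""
    rss = float(np.sum((evt - motif)**2))
    decreases = rss < float(np.sum(evt**2))
    plausible = rss <= float(np.quantile(noise_norm2, q))
    return decreases and plausible

rng = np.random.default_rng(0)
t = np.arange(45) - 22
motif = -10 * np.exp(-t**2 / 8)
# Squared norms of hypothetical noise cuts (unit variance white noise)
noise_norm2 = np.sum(rng.standard_normal((1000, 45))**2, axis=1)

good = motif + 0.5 * rng.standard_normal(45)         # motif plus noise
two_spikes = motif - 8 * np.exp(-(t - 12)**2 / 8)    # unresolved second spike
print(accept_subtraction(good, motif, noise_norm2))
print(accept_subtraction(two_spikes, motif, noise_norm2))
```

In the second case the subtraction does decrease the norm, so the current rule would accept it, but the residual is far larger than anything the noise alone produces.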

11.4. Motifs for subtraction must go back to zero

One last important point before we go ahead: if we want our subtraction to be valid, we must make sure that our motifs return to zero on each side of the valley. This is clearly not the case (at least) for our largest motifs, because the length of the cuts has so far been chosen to optimize discriminability, and relatively little data is needed to classify correctly. So we start by building a new catalog with longer motifs. This catalog, built by calling the mk_center_dictionary function, will contain, in addition to each motif, its first and second derivatives, etc., i.e. all the quantities needed to compensate for jitter. We now build our catalog with motifs that start 49 points before the reference time and extend 80 points after it (before=49, after=80):

centers = { "Cluster " + str(i) :
            swp.mk_center_dictionary(sp0[goodEvts][np.array(c10b)==i],
                                     np.array(data),
                                     before=49, after=80)
            for i in range(10)}

To make sure that our 'motifs for peeling' are long enough, we plot the first five on a common scale together with their time derivatives (Fig. 25).

long_motif_range = [min(centers['Cluster 0']['center']),
                    max(centers['Cluster 0']['center'])]
for i in range(5):
    plt.subplot(5,1,i+1)
    plt.plot(centers['Cluster '+str(i)]['center'],color='blue')
    plt.plot(centers['Cluster '+str(i)]['centerD'],color='orange')
    plt.ylim(long_motif_range)
    plt.axis('off')
    plt.title('Cluster ' + str(i))

first-5-long-motifs.png

Figure 25: First 5 long motifs (blue) and their time derivatives (orange) on the same scale. The traces do indeed go back to zero on both sides of the extrema.

Although it is not really necessary, since their amplitudes are smaller, we do the same for the last 5 motifs (Fig. 26).

for i in range(5,10):
    plt.subplot(5,1,i-4)
    plt.plot(centers['Cluster '+str(i)]['center'],color='blue')
    plt.plot(centers['Cluster '+str(i)]['centerD'],color='orange')
    plt.ylim(long_motif_range)
    plt.axis('off')
    plt.title('Cluster ' + str(i))

last-5-long-motifs.png

Figure 26: Last 5 long motifs (blue) and their time derivatives (orange) on the same scale as Fig. 25. The traces do indeed go back to zero on both sides of the extrema.

These last two figures show that our 'long' motifs are long enough for subtraction.
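The visual check can be complemented by a numerical one. The sketch below assumes, as the plots above suggest, that a motif is stored as the per-site cuts glued one after the other; the function name ends_near_zero, the number of edge points and the tolerance (meaningful if the data are expressed in noise MAD units) are hypothetical choices.

```python
import numpy as np

def ends_near_zero(center, n_sites, n_edge=5, tol=0.5):
    """Check that, on every site, the first and last n_edge points of a
    glued motif (site cuts one after the other) are within tol of zero."""
    segments = np.asarray(center).reshape(n_sites, -1)
    edges = np.hstack([segments[:, :n_edge], segments[:, -n_edge:]])
    return bool(np.all(np.abs(edges) < tol))

def valley(t):
    """Toy motif: a Gaussian-shaped valley."""
    return -10 * np.exp(-t**2 / 8)

t_long = np.arange(-49, 81)    # before=49, after=80
t_short = np.arange(-3, 4)     # a deliberately too short cut
print(ends_near_zero(np.tile(valley(t_long), 4), n_sites=4))
print(ends_near_zero(np.tile(valley(t_short), 4), n_sites=4, n_edge=2))
```

On the actual catalog this would be called as ends_near_zero(centers['Cluster 0']['center'], n_sites=4), assuming the glued layout holds.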

11.5. First peeling

The function classify_and_align_evt is used next. For a given detected event, it finds the closest template, correcting for the jitter, and attributes the event to it if this template is close enough:

swp.classify_and_align_evt(sp0[0],np.array(data),centers)
['Cluster 0', 198, 0.35016584628593517]

We can apply the function to every detected event. A trick here is to store the matrix version of the data once, to avoid converting the list of vectors (holding the data of the different channels) into a matrix for each detected event:

data0 = np.array(data) 
round0 = [swp.classify_and_align_evt(sp0[i],data0,centers)
          for i in range(len(sp0))]

We check how many events are attributed to each cluster and how many remain unclassified ('?'):

nb_total = 0
print('Number of events attributed to each cluster / neuron')
print('----------------------------------------------------')
for i in range(10):
    nb = len([x[1] for x in round0 if x[0] == 'Cluster '+str(i)])
    nb_total += nb
    msg = 'Cluster {0:3}: {1:5}'.format(i,nb)
    print(msg)
nb = len([x[1] for x in round0 if x[0] == '?'])
nb_total += nb
msg = '          ?: {0:5}'.format(nb)
print(msg)
msg = 'Total      : {0:5}'.format(nb_total)
print(msg)
print('----------------------------------------------------')
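Since the same tally is needed again after the second peeling, it could be wrapped in small helpers built on collections.Counter. The names count_by_cluster and print_counts are hypothetical; they only rely on the first element of each item being the label, as in round0.

```python
from collections import Counter

def count_by_cluster(round_lst, n_clusters=10):
    """Number of events per label in a list such as round0 or round1, whose
    elements start with a label: 'Cluster i' or '?' (not classified)."""
    counts = Counter(x[0] for x in round_lst)
    labels = ['Cluster ' + str(i) for i in range(n_clusters)] + ['?']
    return {lab: counts[lab] for lab in labels}

def print_counts(counts):
    """Print the table in the same layout as above."""
    print('Number of events attributed to each cluster / neuron')
    print('----------------------------------------------------')
    for lab, nb in counts.items():
        if lab == '?':
            print('          ?: {0:5}'.format(nb))
        else:
            print('Cluster {0:3}: {1:5}'.format(int(lab.split()[1]), nb))
    print('Total      : {0:5}'.format(sum(counts.values())))
    print('----------------------------------------------------')

# Toy input with the same structure as round0 (label, position, jitter)
toy_round = [['Cluster 0', 198, 0.35], ['?', 410, 0.0], ['Cluster 0', 977, -0.2]]
print_counts(count_by_cluster(toy_round, n_clusters=2))
```

With the document's data, print_counts(count_by_cluster(round0)) would reproduce the table above.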

Using the function predict_data, we create an ideal data trace from the events' positions, the events' origins (the clusters they were attributed to) and a clusters' catalog:

pred0 = swp.predict_data(round0,centers,data_length=data0.shape[1])

We then subtract the prediction (pred0) from the data (data0) to get the "peeled" data (data1):

data1 = data0 - pred0
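The internals of predict_data are not shown in this section. A simplified analogue could look like the sketch below, where predict_trace is a hypothetical name, catalog entries are assumed to be (n_sites, motif_length) arrays (the document's catalog stores the site cuts glued together), and each motif is placed using the second order jitter model used above, center + jitter·centerD + jitter²/2·centerDD; whether predict_data handles jitter exactly this way is an assumption.

```python
import numpy as np

def predict_trace(events, catalog, before, data_length):
    """Hypothetical, simplified analogue of predict_data.

    events: iterable of (label, position, jitter); label '?' is skipped.
    catalog: dict label -> {'center', 'centerD', 'centerDD'}, each an
             (n_sites, motif_length) array (an assumed layout).
    before: number of motif points preceding the reference position.
    """
    n_sites = next(iter(catalog.values()))['center'].shape[0]
    pred = np.zeros((n_sites, data_length))
    for label, pos, jitter in events:
        if label == '?':
            continue
        c = catalog[label]
        # Jittered motif: second order Taylor expansion in the jitter
        motif = c['center'] + jitter*c['centerD'] + jitter**2/2*c['centerDD']
        start = pos - before
        lo = max(start, 0)
        hi = min(start + motif.shape[1], data_length)
        if lo < hi:   # clip motifs overlapping the data edges
            pred[:, lo:hi] += motif[:, lo - start:hi - start]
    return pred

# Toy catalog with a single 2-site motif of length 3
zeros = np.zeros((2, 3))
catalog = {'A': {'center': np.array([[0., -1., 0.], [0., -2., 0.]]),
                 'centerD': zeros, 'centerDD': zeros}}
pred = predict_trace([('A', 5, 0.0), ('?', 2, 0.0), ('A', 0, 0.0)],
                     catalog, before=1, data_length=10)
print(pred)
```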

We can compare the original data with the result of the "first peeling" to get Fig. 27:

for ch in range(4):
    # Site ch is shifted down by 20*ch for display
    plt.plot(tt, data0[ch,]-20*ch, color='black',lw=0.5)
    plt.plot(tt, data1[ch,]-20*ch, color='red',lw=0.3)
plt.xlabel('Time (s)')
plt.xlim([1.8,1.9])
plt.ylim([-70,10])
plt.axis('off')

FirstPeeling.png

Figure 27: 100 ms of the locust data set. Black, original data; red, after first peeling.

11.6. Second peeling

We then take data1 as our former data0 and repeat the procedure, with two slight modifications: detection is done on a single recording site, and a shorter filter is used before detecting the events. Doing detection on a single site (here site 0) corrects a drawback of our crude spike detection method. When we used it the first time, we summed the filtered and rectified versions of the data before looking for peaks. This summation can lead to badly defined spike times when two neurons that are large on different recording sites, say site 0 and site 1, fire at nearly the same time: the summed event can then have a peak in between the two true peaks, and our jitter correction cannot resolve that. We therefore perform detection on individual sites (here only site 0). Jitter estimation and subtraction are still done on all 4 recording sites:

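# Smooth each site with a 5-point moving average (shorter filter)
data_filtered = apply(lambda x:
                      fftconvolve(x,np.array([1,1,1,1,1])/5.,'same'),
                      1,np.array(data1))
# Normalize each site by the MADs computed earlier
data_filtered = (data_filtered.transpose() / \
                 dfiltered_mad_original).transpose()
# Keep only deflections below -4 (in MAD units)
data_filtered[data_filtered > -4] = 0
# Detect events (peaks of the negated trace) on site 0 only
sp1 = swp.peak(-data_filtered[0,:])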
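The summation problem described above can be reproduced with a toy example. The two Gaussian-shaped deflections, standing for already filtered and rectified signals of two neurons each large on a different site, and their 3-sample offset are assumptions for illustration.

```python
import numpy as np

t = np.arange(200)
# Hypothetical filtered and rectified deflections: one neuron large on
# site 0 peaking at t = 100, another large on site 1 peaking at t = 103.
site0 = np.exp(-(t - 100)**2 / 8)
site1 = np.exp(-(t - 103)**2 / 8)
summed = site0 + site1    # what detection on the summed trace sees

print(np.argmax(site0), np.argmax(site1), np.argmax(summed))
```

Each site taken separately gives the right peak time, while the summed trace has a single peak lying between the two true times.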

We classify the events and obtain the new prediction and the new "data":

round1 = [swp.classify_and_align_evt(sp1[i],data1,centers)
          for i in range(len(sp1))]
pred1 = swp.predict_data(round1,centers,data_length=data1.shape[1])
data2 = data1 - pred1

We can check how many events are attributed to each cluster:

nb_total = 0
print('Number of events attributed to each cluster / neuron')
print('----------------------------------------------------')
for i in range(10):
    nb = len([x[1] for x in round1 if x[0] == 'Cluster '+str(i)])
    nb_total += nb
    msg = 'Cluster {0:3}: {1:5}'.format(i,nb)
    print(msg)
nb = len([x[1] for x in round1 if x[0] == '?'])
nb_total += nb
msg = '          ?: {0:5}'.format(nb)
print(msg)
msg = 'Total      : {0:5}'.format(nb_total)
print(msg)
print('----------------------------------------------------')

We see that most of the events attributed at this stage go to the 'bad' clusters (7, 8, 9), which strongly suggests that the job is done, or nearly so.

We can compare the first peeling with the second one (Fig. 28):

for ch in range(4):
    # Site ch is shifted down by 20*ch for display
    plt.plot(tt, data1[ch,]-20*ch, color='black',lw=0.5)
    plt.plot(tt, data2[ch,]-20*ch, color='red',lw=0.3)
plt.xlabel('Time (s)')
plt.xlim([1.8,1.9])
plt.ylim([-70,10])
plt.axis('off')

SecondPeeling.png

Figure 28: 100 ms of the locust data set. Black, first peeling; red, second peeling.

We can also take a look at the non-classified events (Fig. 29):

nc_r1_pos = [x[1] for x in round1 if x[0] == '?']
nc_r1_evts = swp.mk_events(nc_r1_pos,data1,14,30)
swp.plot_events(nc_r1_evts,show_median=False,show_mad=False,events_lw=0.5)

non-classified-round1.png

Figure 29: The 8 non-classified events at the second peeling stage.

This confirms that there is nothing interesting left; we can stop here.

12. Spike trains extraction and a test

12.1. Spike trains

If we sort action potentials, it is obviously not to obtain a peeled trace like the red trace of our penultimate figure (Fig. 28). What we are really interested in is obtaining the (time) sequences of action potentials, to study how these sequences are modified by the presentation of stimuli, or whether the sequences of different neurons interact with each other. We will not go that far here, but we will do some basic checks on the sequences to validate the quality of our classification. First, we extract the sequences, converting the spike times to ms (sample indices are divided by 15, i.e. 15 samples per ms) without taking the jitter correction into account:

trains = [sorted([x[1]/15.0 for x in round0 if x[0] == 'Cluster '+str(i)]+
           [x[1]/15.0 for x in round1 if x[0] == 'Cluster '+str(i)])
          for i in range(10)]

12.2. Refractory period

As a simple test, we compute the inter-spike intervals (ISIs) of each train and look at their 5-number summary (particularly the minimum, which should be compatible with the refractory period):

isi_lst = [[st[i+1]-st[i] for i in range(len(st)-1)] for st in trains]
np.set_printoptions(precision=1)
[mquantiles(x,prob=[0,0.25,0.5,0.75,1]) for x in isi_lst]
[array([  22.6,   37.3,   54.2,  118.8, 2026.7]),
 array([  23.8,   45.3,   68.1,  301. , 1472.3]),
 array([  18.6,   33.3,   51.3,  226. , 3456.4]),
 array([  21.8,   56.1,   86.2,  415.9, 2724.1]),
 array([   5.9,   43.3,   61.7,  126.9, 1265.1]),
 array([  29.9,   73.1,  159.9,  676.1, 5435.4]),
 array([  25.3,   43. ,   83.9,  192.1, 1939.5]),
 array([7.3e-01, 3.9e+01, 5.5e+01, 9.1e+01, 9.9e+02]),
 array([  1.6,  25.2,  42. ,  71.4, 577.3]),
 array([  0. ,  21.3,  40.5,  75.8, 435. ])]

We see, as expected, that the first 7 neurons / clusters (first 7 rows above) give rise to spike trains exhibiting a proper refractory period (the first column gives the smallest inter-spike interval in ms), while the last three have virtually no such period.
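A complementary check is the fraction of ISIs shorter than a nominal refractory period. In the sketch below, the function name refractory_violations and the 2 ms default are assumptions; the appropriate value depends on the preparation.

```python
import numpy as np

def refractory_violations(spike_times_ms, refractory_ms=2.0):
    """Fraction of inter-spike intervals shorter than refractory_ms."""
    isi = np.diff(np.sort(np.asarray(spike_times_ms, dtype=float)))
    if isi.size == 0:
        return 0.0
    return float(np.mean(isi < refractory_ms))

# One of the three ISIs (1, 9 and 20 ms) is shorter than 2 ms
print(refractory_violations([0.0, 1.0, 10.0, 30.0]))
```

Applied to the sorted trains, [refractory_violations(st) for st in trains] gives one number per cluster; a well isolated neuron should score (close to) zero.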

13. Individual function definitions

Short functions are presented in 'one piece'. Longer ones are presented with their docstring first, followed by the body of the function. To get the actual function, you should replace the <<docstring>> appearing in the function definition by the actual docstring. This is a direct application of the literate programming paradigm. More complicated functions are split into more parts, each with its own description.

13.1. plot_data_list

We define a function, plot_data_list, that makes our raw data displaying commands lighter, starting with its docstring:

"""Plots data when individual recording channels make up elements
of a list.

Parameters
----------
data_list: a list of numpy arrays of dimension 1 that should all
           be of the same length (not checked).
time_axes: an array with as many elements as the components of
           data_list. The time values of the abscissa.
linewidth: the width of the lines drawing the curves.
color: the color of the curves.

Returns
-------
Nothing is returned, the function is used for its side effect: a
plot is generated. 
"""

Then the definition of the function per se:

def plot_data_list(data_list,
                   time_axes,
                   linewidth=0.2,
                   color='black'):
    <<plot_data_list-docstring>>
    nb_chan = len(data_list)
    data_min = [np.min(x) for x in data_list]
    data_max = [np.max(x) for x in data_list]
    display_offset = list(np.cumsum(np.array([0] +
                                             [data_max[i]-
                                              data_min[i-1]
                                             for i in
                                             range(1,nb_chan)])))
    for i in range(nb_chan):
        plt.plot(time_axes,data_list[i]-display_offset[i],
                 linewidth=linewidth,color=color)
    plt.yticks([])
    plt.xlabel("Time (s)")

13.2. peak

We define the function peak, which detects local maxima using an estimate (a central difference) of the derivative of the signal. A putative maximum is kept only if the next one is more than minimal_dist sampling points away; as a consequence, the last putative maximum is always discarded. The function returns a vector of indices. Its docstring is:

"""Find peaks on one dimensional arrays.

Parameters
----------
x: a one dimensional array on which scipy.signal.fftconvolve can
   be called.
minimal_dist: the minimal distance between two successive peaks.
not_zero: the smallest value above which the absolute value of
          the derivative is considered not null.

Returns
-------
An array of (peak) indices is returned.
"""

And the function per se:

def peak(x, minimal_dist=15, not_zero=1e-3):
    <<peak-docstring>>
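    ## Get the first derivative (central difference)
    dx = scipy.signal.fftconvolve(x,np.array([1,0,-1])/2.,'same')
    ## Treat very small derivative values as zero
    dx[np.abs(dx) < not_zero] = 0
    ## A negative change of the derivative's sign marks a putative maximum
    dx = np.diff(np.sign(dx))
    pos = np.arange(len(dx))[dx < 0]
    ## Keep a putative maximum only if the next one is far enough
    return pos[:-1][np.diff(pos) > minimal_dist]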

13.3. cut_sgl_evt

The function mk_events (defined next), which we will use directly, calls cut_sgl_evt. As its name suggests, cut_sgl_evt cuts a single event and returns a vector with the cuts on the different recording sites glued one after the other. Its docstring is:

"""Cuts an 'event' at 'evt_pos' on 'data'.
    
Parameters
----------
evt_pos: an integer, the index (location) of the (peak of) the
         event.
data: a matrix whose rows contain the recording channels.
before: an integer, how many points should be within the cut
        before the reference index / time given by evt_pos.
after: an integer, how many points should be within the cut
       after the reference index / time given by evt_pos.
    
Returns
-------
A vector with the cuts on the different recording sites glued
one after the other. 
"""

And the function per se:

def cut_sgl_evt(evt_pos,data,before=14, after=30):
    <<cut_sgl_evt-docstring>>
    ns = data.shape[0] ## Number of recording sites
    dl = data.shape[1] ## Number of sampling points
    cl = before+after+1 ## The length of the cut
    cs = cl*ns ## The 'size' of a cut
    cut = np.zeros((ns,cl))
    idx = np.arange(-before,after+1)
    keep = idx + evt_pos
    within = np.bitwise_and(0 <= keep, keep < dl)
    kw = keep[within]
    cut[:,within] = data[:,kw].copy()
    return cut.reshape(cs) 
  

13.4. mk_events

The function mk_events takes a vector of indices as its first argument and returns a matrix with as many rows as there are events. Its docstring is:

"""Make events matrix out of data and events positions.
    
Parameters
----------
positions: a vector containing the indices of the events.
data: a matrix whose rows contain the recording channels.
before: an integer, how many points should be within the cut
        before the reference index / time given by positions.
after: an integer, how many points should be within the cut
       after the reference index / time given by positions.
    
Returns
-------
A matrix with as many rows as events and whose rows are the cuts
on the different recording sites glued one after the other. 
"""

And the function per se:

def mk_events(positions, data, before=14, after=30):
    <<mk_events-docstring>>
    res = np.zeros((len(positions),(before+after+1)*data.shape[0]))
    for i,p in enumerate(positions):
        res[i,:] = cut_sgl_evt(p,data,before,after)
    return res 
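Since mk_events loops over events in Python, a vectorized variant may be useful on long recordings. The sketch below, mk_events_vectorized (a hypothetical name), zero pads the data with np.pad and gathers all cuts with a single fancy-indexing operation; it assumes every position lies within the data (0 <= position < data.shape[1]) and reproduces the zero filling of cut_sgl_evt at the edges.

```python
import numpy as np

def mk_events_vectorized(positions, data, before=14, after=30):
    """Hypothetical vectorized variant of mk_events.

    Assumes 0 <= position < data.shape[1] for every position; samples
    falling outside the data are set to zero, as in cut_sgl_evt."""
    positions = np.asarray(positions, dtype=int)
    n_sites = data.shape[0]
    cut_length = before + after + 1
    # Zero pad each site so that every cut stays within the array
    padded = np.pad(data, ((0, 0), (before, after)))
    # idx[e, k]: column of padded holding sample positions[e] - before + k
    idx = positions[:, None] + np.arange(cut_length)[None, :]
    cuts = padded[:, idx]     # shape (n_sites, n_events, cut_length)
    # Glue the per-site cuts one after the other, event by event
    return cuts.transpose(1, 0, 2).reshape(len(positions),
                                           n_sites * cut_length)

rng = np.random.default_rng(1)
toy_data = rng.standard_normal((4, 100))
evts_toy = mk_events_vectorized([0, 3, 50, 97, 99], toy_data)
print(evts_toy.shape)    # (5, 180): 5 events, 4 sites x 45 points
```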

13.5. plot_events

In order to facilitate events display, we define an event specific plotting function starting with its docstring:

"""Plot events.
    
Parameters
----------
evts_matrix: a matrix of events. Rows are events. Cuts from
             different recording sites are glued one after the
             other on each row.
n_plot: an integer, the number of events to plot (if 'None',
        default, all are shown).
n_channels: an integer, the number of recording channels.
events_color: the color used to display events. 
events_lw: the line width used to display events. 
show_median: should the median event be displayed?
median_color: color used to display the median event.
median_lw: line width used to display the median event.
show_mad: should the MAD be displayed?
mad_color: color used to display the MAD.
mad_lw: line width used to display the MAD.

Returns
-------
Nothing is returned, the function is used for its side effect.
"""

And the function per se:

def plot_events(evts_matrix, 
                n_plot=None,
                n_channels=4,
                events_color='black', 
                events_lw=0.1,
                show_median=True,
                median_color='red',
                median_lw=0.5,
                show_mad=True,
                mad_color='blue',
                mad_lw=0.5):
    <<plot_events-docstring>>
    if n_plot is None:
        n_plot = evts_matrix.shape[0]

    cut_length = evts_matrix.shape[1] // n_channels 
    
    for i in range(n_plot):
        plt.plot(evts_matrix[i,:], color=events_color, lw=events_lw)
    if show_median:
        MEDIAN = np.apply_along_axis(np.median,0,evts_matrix)
        plt.plot(MEDIAN, color=median_color, lw=median_lw)

    if show_mad:
        MAD = np.apply_along_axis(mad,0,evts_matrix)
        plt.plot(MAD, color=mad_color, lw=mad_lw)
    
    left_boundary = np.arange(cut_length,
                              evts_matrix.shape[1],
                              cut_length*2)
    for l in left_boundary:
        plt.axvspan(l,l+cut_length-1,
                    facecolor='grey',alpha=0.5,edgecolor='none')
    plt.xticks([])
    return

13.6. plot_data_list_and_detection

We define a function, plot_data_list_and_detection, that makes our data and detection displaying commands lighter. Its docstring:

"""Plots data together with detected events.
    
Parameters
----------
data_list: a list of numpy arrays of dimension 1 that should all
           be of the same length (not checked).
time_axes: an array with as many elements as the components of
           data_list. The time values of the abscissa.
evts_pos: a vector containing the indices of the detected
          events.
linewidth: the width of the lines drawing the curves.
color: the color of the curves.

Returns
-------
Nothing is returned, the function is used for its side effect: a
plot is generated. 
"""

And the function:

def plot_data_list_and_detection(data_list,
                                 time_axes,
                                 evts_pos,
                                 linewidth=0.2,
                                 color='black'):                             
    <<plot_data_list_and_detection-docstring>>
    nb_chan = len(data_list)
    data_min = [np.min(x) for x in data_list]
    data_max = [np.max(x) for x in data_list]
    display_offset = list(np.cumsum(np.array([0] +
                                             [data_max[i]-
                                              data_min[i-1] for i in
                                             range(1,nb_chan)])))
    for i in range(nb_chan):
        plt.plot(time_axes,data_list[i]-display_offset[i],
                 linewidth=linewidth,color=color)
        plt.plot(time_axes[evts_pos],
                 data_list[i][evts_pos]-display_offset[i],'ro')
    plt.yticks([])
    plt.xlabel("Time (s)")

13.7. mad

We define the mad function in one piece since it is very short:

def mad(x):
    """Returns the Median Absolute Deviation of its argument.
    """
    return np.median(np.absolute(x - np.median(x)))*1.4826
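The constant 1.4826 makes the MAD a consistent estimator of the standard deviation for Gaussian data: for a normal distribution, the median absolute deviation equals \(\sigma \, \Phi^{-1}(3/4)\), and \(1/\Phi^{-1}(3/4) \approx 1.4826\). The sketch below checks this with the standard library's statistics.NormalDist and compares mad with np.std on simulated data; the sample size, the scale of 3 and the 1% of outliers are arbitrary choices.

```python
import numpy as np
from statistics import NormalDist

# For a Gaussian, MAD = sigma * Phi^{-1}(3/4); the constant undoes this factor
k = 1 / NormalDist().inv_cdf(0.75)
print(round(k, 4))    # 1.4826

def mad(x):
    """Returns the Median Absolute Deviation of its argument."""
    return np.median(np.absolute(x - np.median(x)))*1.4826

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=100_000)
x_out = x.copy()
x_out[:1000] = 1000.0   # 1% of gross outliers
print(mad(x), np.std(x))            # both close to 3
print(mad(x_out), np.std(x_out))    # MAD barely moves, SD explodes
```

This robustness to outliers (spikes, in our case) is why the MAD rather than the SD is used to estimate the noise level.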

13.8. splom

We define a function constructing a scatterplot matrix. The code is taken from Stack Overflow:

import itertools

def splom(data, names=[], **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        if ax.get_subplotspec().is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.get_subplotspec().is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.get_subplotspec().is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.get_subplotspec().is_last_row():
            ax.xaxis.set_ticks_position('bottom')
    
    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

13.9. mk_aligned_events

We start with the chunk importing the required functions from the different modules (<<mk_aligned_events-import-functions>>):

from scipy.signal import fftconvolve
from numpy import apply_along_axis as apply
from scipy.spatial.distance import squareform

We then get the first and second derivatives of the data (<<mk_aligned_events-dataD-and-dataDD>>):

dataD = apply(lambda x: fftconvolve(x,np.array([1,0,-1])/2., 'same'),
              1, data)
dataDD = apply(lambda x: fftconvolve(x,np.array([1,0,-1])/2.,'same'),
               1, dataD)
    

Events are cut from the different data 'versions', derivatives of order 0, 1 and 2 (<<mk_aligned_events-get-events>>):

evts = mk_events(positions, data, before, after)
evtsD = mk_events(positions, dataD, before, after)
evtsDD = mk_events(positions, dataDD, before, after)    

A center or template is obtained by taking the pointwise median of the events we just got on the three versions of the data (<<mk_aligned_events-get-centers>>):

center = apply(np.median,0,evts)
centerD = apply(np.median,0,evtsD)
centerD_norm2 = np.dot(centerD,centerD)
centerDD = apply(np.median,0,evtsDD)
centerDD_norm2 = np.dot(centerDD,centerDD)
centerD_dot_centerDD = np.dot(centerD,centerDD)

Given an event, we make a first order jitter estimation and compute the squared norm of the initial residual, h_order0_norm2, and of its first order jitter corrected version, h_order1_norm2 (<<mk_aligned_events-do-job-on-single-event-order1>>):

h = evt - center
h_order0_norm2 = sum(h**2)
h_dot_centerD = np.dot(h,centerD)
jitter0 = h_dot_centerD/centerD_norm2
h_order1_norm2 = sum((h-jitter0*centerD)**2)
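For reference, jitter0 is the minimizer of the first order version of the RSS introduced for event 13 of cluster 3, that is, the second order expression truncated after the \(\mathbf{f'}\) term, with centerD playing the role of \(\mathbf{f'}\) and \(\mathbf{h} = \mathbf{g} - \mathbf{f}\): \[ \mathrm{RSS}_1(\delta{}) = \| \mathbf{h} - \delta{}\, \mathbf{f'} \|^2 = \| \mathbf{h} \|^2 - 2\, \delta{} \, \mathbf{h} \cdot \mathbf{f'} + \delta{}^2 \, \|\mathbf{f'}\|^2 \, , \] \[ \mathrm{RSS}_1'(\delta{}) = - 2\, \mathbf{h} \cdot \mathbf{f'} + 2 \, \delta{} \, \|\mathbf{f'}\|^2 = 0 \quad \Longrightarrow \quad \delta{}^{(0)} = \frac{\mathbf{h} \cdot \mathbf{f'}}{\|\mathbf{f'}\|^2} \, , \] which is h_dot_centerD/centerD_norm2. Since \(\mathrm{RSS}_1''(\delta{}) = 2\,\|\mathbf{f'}\|^2 > 0\), this stationary point is a minimum.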

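If the residual's squared norm decreases upon first order jitter correction, a second order correction is tried: a single Newton-Raphson step starting from the first order estimate, where first and second are \(\mathrm{RSS}'\) and \(\mathrm{RSS}''\) evaluated at jitter0. At the end, the squared norm of the second order jitter corrected residual (h_order2_norm2) is compared with the first order one (h_order1_norm2). If the former is larger than or equal to the latter, the estimated jitter is set back to its first order value (<<mk_aligned_events-do-job-on-single-event-order2>>):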

h_dot_centerDD = np.dot(h,centerDD)
first = -2*h_dot_centerD + \
  2*jitter0*(centerD_norm2 - h_dot_centerDD) + \
  3*jitter0**2*centerD_dot_centerDD + \
  jitter0**3*centerDD_norm2
second = 2*(centerD_norm2 - h_dot_centerDD) + \
  6*jitter0*centerD_dot_centerDD + \
  3*jitter0**2*centerDD_norm2
jitter1 = jitter0 - first/second
h_order2_norm2 = sum((h-jitter1*centerD- \
                      jitter1**2/2*centerDD)**2)
if h_order1_norm2 <= h_order2_norm2:
    jitter1 = jitter0

And now the function's docstring (<<mk_aligned_events-docstring>>):

"""Align events on the central event using first or second order
Taylor expansion.

Parameters
----------
positions: a vector of indices with the positions of the
           detected events. 
data: a matrix whose rows contain the recording channels.
before: an integer, how many points should be within the cut
        before the reference index / time given by positions.
after: an integer, how many points should be within the cut
       after the reference index / time given by positions.
   
Returns
-------
A tuple whose elements are:
  A matrix with as many rows as events and whose rows are the
  cuts on the different recording sites glued one after the
  other. These events have been jitter corrected using the
  second order Taylor expansion.
  A vector of events positions where "actual" positions have
  been rounded to the nearest index.
  A vector of jitter values.
  
Details
------- 
(1) The data first and second derivatives are estimated first.
(2) Events are cut next on each of the three versions of the data.
(3) The global median event for each of the three versions is
obtained.
(4) Each event is then aligned on the median using a first order
Taylor expansion.
(5) If this alignment decreases the squared norm of the residual,
(6) an improvement is looked for using a second order expansion.
If the estimated jitter, once rounded to the nearest integer, is
not zero, the whole procedure is repeated after cutting the event
anew at the corrected peak position (7).
"""

Finally, the function itself:

def mk_aligned_events(positions, data, before=14, after=30):
    <<mk_aligned_events-docstring>>
    <<mk_aligned_events-import-functions>>
    n_evts = len(positions)
    new_positions = positions.copy()
    jitters = np.zeros(n_evts)
    # Details (1)
    <<mk_aligned_events-dataD-and-dataDD>>
    # Details (2)
    <<mk_aligned_events-get-events>>
    # Details (3)
    <<mk_aligned_events-get-centers>>
    # Details (4)
    for evt_idx in range(n_evts):
        # Details (5)
        evt = evts[evt_idx,:]
        evt_pos = positions[evt_idx]
        <<mk_aligned_events-do-job-on-single-event-order1>>
        if h_order0_norm2 > h_order1_norm2:
            # Details (6)
            <<mk_aligned_events-do-job-on-single-event-order2>>
        else:
            jitter1 = 0
        if abs(round(jitter1)) > 0:
            # Details (7)
            evt_pos -= int(round(jitter1))
            evt = cut_sgl_evt(evt_pos,data=data,
                              before=before, after=after)
            <<mk_aligned_events-do-job-on-single-event-order1>>               
            if h_order0_norm2 > h_order1_norm2:
                <<mk_aligned_events-do-job-on-single-event-order2>>
            else:
                jitter1 = 0
        if sum(evt**2) > sum((h-jitter1*centerD-
                              jitter1**2/2*centerDD)**2):
            evts[evt_idx,:] = evt-jitter1*centerD- \
                jitter1**2/2*centerDD
        new_positions[evt_idx] = evt_pos 
        jitters[evt_idx] = jitter1
    return (evts, new_positions,jitters)
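The sign convention and the role of the rounding test can be illustrated on a toy example. The sketch below assumes a single recording site and a Gaussian shaped 'spike'; the helper bump, the width sigma = 3 and the helper first_order_jitter are illustrative choices, not part of the original code. An event obeying evt(t) = center(t + jitter) is generated for two sub-sample shifts and the first order estimate jitter0 is computed as in the function, the derivative being obtained with np.gradient, which uses the same central difference (x[i+1]-x[i-1])/2 in the interior.

```python
import numpy as np

# Hypothetical toy set-up: one site, Gaussian "spike" of width sigma
t = np.arange(-15, 16)
sigma = 3.0

def bump(x):
    return np.exp(-x**2/(2*sigma**2))

center = bump(t)
centerD = np.gradient(center)        # central differences in the interior
centerD_norm2 = np.dot(centerD, centerD)

def first_order_jitter(true_jitter):
    """Return (jitter0, h_order0_norm2, h_order1_norm2) for an event
    such that evt(t) = center(t + true_jitter)."""
    evt = bump(t + true_jitter)
    h = evt - center
    jitter0 = np.dot(h, centerD)/centerD_norm2
    h_order0_norm2 = np.sum(h**2)
    h_order1_norm2 = np.sum((h - jitter0*centerD)**2)
    return jitter0, h_order0_norm2, h_order1_norm2

for true_jitter in (0.3, 1.4):
    jitter0, n0, n1 = first_order_jitter(true_jitter)
    # A non zero rounded jitter is what triggers re-cutting the event
    print(true_jitter, round(jitter0, 2), n1 < n0, int(round(jitter0)))
```

With these settings the 0.3 shift should be recovered to within a few hundredths of a sampling point, and the first order residual is smaller than the uncorrected one. The 1.4 shift gives a jitter that rounds to 1, which is the situation where mk_aligned_events moves evt_pos and cuts the event again.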

13.10. mk_center_dictionary

We define function mk_center_dictionary starting with its docstring:

""" Computes clusters 'centers' or templates and associated data.

Clusters' centers should be built such that they can be used for
subtraction. This implies that they should be long enough, on
both sides of the peak, to see them go back to baseline. Formal
parameters before and after below should therefore be set to
larger values than the ones used for clustering.

Parameters
----------
positions : a vector of spike times, that should all come from the
            same cluster and correspond to reasonably 'clean'
            events.
data : a data matrix.
before : the number of sampling points to keep before the peak.
after : the number of sampling points to keep after the peak.

Returns
-------
A dictionary with the following components:
  center: the estimate of the center (obtained from the median).
  centerD: the estimate of the center's derivative (obtained from
           the median of events cut on the derivative of data).
  centerDD: the estimate of the center's second derivative
            (obtained from the median of events cut on the second
            derivative of data).
  centerD_norm2: the squared norm of the center's derivative.
  centerDD_norm2: the squared norm of the center's second
                  derivative.
  centerD_dot_centerDD: the scalar product of the center's first
                        and second derivatives.
  center_idx: an array of indices generated by
              np.arange(-before,after+1).
 """

The function starts by evaluating the first two derivatives of the data (<<get-derivatives>>):

from scipy.signal import fftconvolve
from numpy import apply_along_axis as apply
dataD = apply(lambda x:
              fftconvolve(x,np.array([1,0,-1])/2.,'same'),
              1, data)
dataDD = apply(lambda x:
               fftconvolve(x,np.array([1,0,-1])/2.,'same'),
               1, dataD)
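The kernel [1, 0, -1]/2 used with mode 'same' implements the central difference (x[i+1]-x[i-1])/2. The sketch below checks this on made-up polynomial 'channels'. It also shows that applying the kernel twice, as done for dataDD, amounts to a single convolution with [1, 0, -2, 0, 1]/4, and that the first and last samples are affected by the implicit zero padding of the convolution.

```python
import numpy as np
from scipy.signal import fftconvolve
from numpy import apply_along_axis as apply

# Two toy "channels" sampled on 10 points
t = np.arange(10.0)
data = np.vstack([t**2, t**3])

kernel = np.array([1, 0, -1])/2.
dataD = apply(lambda x: fftconvolve(x, kernel, 'same'), 1, data)

interior = slice(1, -1)
# Interior: central difference, exact for t**2 (derivative 2t)
print(np.allclose(dataD[0, interior], 2*t[interior]))
# Same interior values as numpy's own central differences
print(np.allclose(dataD[:, interior],
                  np.gradient(data, axis=1)[:, interior]))
# Edges see the zero padding: (x[1] - 0)/2 and (0 - x[-2])/2
print(dataD[0, 0], dataD[0, -1])
# Differentiating twice is one convolution with [1, 0, -2, 0, 1]/4
print(np.convolve(kernel, kernel))
```

The edge values are therefore not meaningful derivative estimates, which is one more reason to keep events away from the very beginning and end of the recording.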
    

The function is defined next:

def mk_center_dictionary(positions, data, before=49, after=80):
    <<mk_center_dictionary-docstring>>
    <<mk_center_dictionary-get-derivatives>>
    evts = mk_events(positions, data, before, after)
    evtsD = mk_events(positions, dataD, before, after)
    evtsDD = mk_events(positions, dataDD, before, after)
    evts_median = apply(np.median,0,evts)
    evtsD_median = apply(np.median,0,evtsD)
    evtsDD_median = apply(np.median,0,evtsDD)
    return {"center" : evts_median, 
            "centerD" : evtsD_median, 
            "centerDD" : evtsDD_median, 
            "centerD_norm2" : np.dot(evtsD_median,evtsD_median),
            "centerDD_norm2" : np.dot(evtsDD_median,evtsDD_median),
            "centerD_dot_centerDD" : np.dot(evtsD_median,
                                            evtsDD_median), 
            "center_idx" : np.arange(-before,after+1)}

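The logic of mk_center_dictionary can be reproduced on synthetic data. In the sketch below, cut_events is a hypothetical stand-in for mk_events (not defined in this section), assumed to return one row per event with the cuts on the successive sites glued one after the other; the recording, spike shape and positions are made up, and the derivatives are obtained with np.gradient, which matches the central differences above in the interior. Note that np.median(evts, axis=0) gives the same result as apply(np.median, 0, evts) in a single vectorized call.

```python
import numpy as np

def cut_events(positions, data, before, after):
    """Hypothetical stand-in for mk_events: one row per event, the
    cuts [pos-before, pos+after] on every site glued together."""
    return np.array([data[:, p - before:p + after + 1].ravel()
                     for p in positions])

rng = np.random.default_rng(1)
n_sites, n_samples, before, after = 2, 1000, 10, 10
t = np.arange(-before, after + 1)
# Toy spike: same shape on both sites, half amplitude on the second
waveform = np.vstack([-np.exp(-t**2/8.0), -0.5*np.exp(-t**2/8.0)])
positions = np.array([100, 300, 550, 800])
data = rng.normal(scale=0.01, size=(n_sites, n_samples))
for p in positions:
    data[:, p - before:p + after + 1] += waveform

dataD = np.gradient(data, axis=1)
dataDD = np.gradient(dataD, axis=1)

center = np.median(cut_events(positions, data, before, after), axis=0)
centerD = np.median(cut_events(positions, dataD, before, after), axis=0)
centerDD = np.median(cut_events(positions, dataDD, before, after), axis=0)
centers = {"center": center,
           "centerD": centerD,
           "centerDD": centerDD,
           "centerD_norm2": np.dot(centerD, centerD),
           "centerDD_norm2": np.dot(centerDD, centerDD),
           "centerD_dot_centerDD": np.dot(centerD, centerDD),
           "center_idx": np.arange(-before, after + 1)}
print(center.shape, np.max(np.abs(center - waveform.ravel())))
```

With four clean events and little noise, the median center should lie very close to the waveform that was planted in the data.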
13.11. classify_and_align_evt

We now define function classify_and_align_evt, starting with its docstring (<<classify_and_align_evt-docstring>>):

"""Compares a single event to a dictionary of centers and returns
the name of the closest center if it is close enough or '?', the
corrected peak position and the remaining jitter.

Parameters
----------
evt_pos : a sampling point at which an event was detected.
data : a data matrix.
centers : a centers' dictionary returned by mk_center_dictionary.
before : the number of sampling points to consider before the peak.
after : the number of sampling points to consider after the peak.

Returns
-------
A list with the following components:
  The name of the closest center if it was close enough or '?'.
  The nearest sampling point to the event's peak.
  The jitter: difference between the estimated actual peak
  position and the nearest sampling point.
"""

The first chunk of the function takes the dictionary of centers, centers, generated by mk_center_dictionary, defines two variables, cluster_names and n_sites, and builds a matrix of centers, centersM, with one row per cluster (<<centersM>>):

cluster_names = np.sort(list(centers))
n_sites = data.shape[0]
# For each center, keep the samples whose time index lies within
# [-before, after], on every recording site (hence np.tile)
centersM = np.array([centers[c_name]["center"]
                     [np.tile((-before <= centers[c_name]["center_idx"]) &
                              (centers[c_name]["center_idx"] <= after),
                              n_sites)]
                     for c_name in cluster_names])
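Because centers are built with longer cuts (before=49, after=80 by default in mk_center_dictionary) than those used for classification (before=14, after=30), each center has to be trimmed before being compared with an event. The boolean mask built from center_idx selects the time indices within [-before, after], and np.tile repeats it once per site since the sites are glued one after the other. A toy illustration, with small made-up window sizes and values that encode the site and the time index:

```python
import numpy as np

n_sites = 2
center_idx = np.arange(-5, 8 + 1)            # long window: before=5, after=8
# Toy center: site 0 holds 100 + time index, site 1 holds 200 + time index
center = np.concatenate([100 + center_idx, 200 + center_idx])

before, after = 2, 3                         # shorter classification window
mask = np.tile((-before <= center_idx) & (center_idx <= after), n_sites)
print(center[mask].tolist())
# [98, 99, 100, 101, 102, 103, 198, 199, 200, 201, 202, 203]
```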

Cut the event to classify, evt, subtract each center from it to get delta, and take as the closest center, cluster_idx, the one giving the smallest squared Euclidean norm (<<cluster_idx>>):

evt = cut_sgl_evt(evt_pos,data=data,before=before, after=after)
delta = -(centersM - evt)
cluster_idx = np.argmin(np.sum(delta**2,axis=1))    
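The subtraction relies on NumPy broadcasting: evt, a vector, is subtracted from every row of centersM at once, and the squared norms are summed row by row. A toy illustration with three made-up centers of four samples each:

```python
import numpy as np

# Three hypothetical centers (one per row) and one event
centersM = np.array([[0., -1., -3., -1.],
                     [0., -2., -6., -2.],
                     [0.,  1.,  2.,  1.]])
evt = np.array([0.1, -1.8, -5.5, -2.2])
# Broadcasting: delta[i] = evt - centersM[i] for every row i
delta = -(centersM - evt)
dist2 = np.sum(delta**2, axis=1)
cluster_idx = np.argmin(dist2)
print(int(cluster_idx))
# 1
```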

Get the name of the selected cluster, good_cluster_name, and the boolean mask, good_cluster_idx, selecting its samples whose time index lies within [-before, after] on every site. Then, extract the first two derivatives of the center, centerD and centerDD, their squared norms, centerD_norm2 and centerDD_norm2, and their dot product, centerD_dot_centerDD (<<get-centers>>):

good_cluster_name = cluster_names[cluster_idx]
good_cluster_idx = np.tile((-before <= centers[good_cluster_name]["center_idx"]) &
                           (centers[good_cluster_name]["center_idx"] <= after),
                           n_sites)
centerD = centers[good_cluster_name]["centerD"][good_cluster_idx]
centerD_norm2 = np.dot(centerD,centerD)
centerDD = centers[good_cluster_name]["centerDD"][good_cluster_idx]
centerDD_norm2 = np.dot(centerDD,centerDD)
centerD_dot_centerDD = np.dot(centerD,centerDD)

Do a first order jitter correction, where h contains the difference between the event and the center. Obtain the squared norm of the uncorrected residual, h_order0_norm2, the estimated jitter, jitter0, and the squared norm of the first order corrected residual, h_order1_norm2 (<<jitter-order-1>>):

h_order0_norm2 = sum(h**2)
h_dot_centerD = np.dot(h,centerD)
jitter0 = h_dot_centerD/centerD_norm2
h_order1_norm2 = sum((h-jitter0*centerD)**2)     
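Writing \(h\) for the residual, \(D\) for centerD and \(\delta\) for the jitter, jitter0 is the least squares solution of the first order model \(h \approx \delta D\), and the corresponding residual norm follows by substitution:

```latex
\delta_0 = \arg\min_{\delta} \lVert h - \delta D \rVert^2
         = \frac{\langle h, D \rangle}{\lVert D \rVert^2},
\qquad
\lVert h - \delta_0 D \rVert^2
  = \lVert h \rVert^2 - \frac{\langle h, D \rangle^2}{\lVert D \rVert^2}.
```

The second identity shows that h_order1_norm2 is smaller than h_order0_norm2 whenever \(\langle h, D \rangle \neq 0\), so the test h_order0_norm2 > h_order1_norm2 in the function only fails when the residual is orthogonal to centerD.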

Do a second order jitter correction. Obtain the estimated jitter, jitter1 and the squared norm of the second order corrected residual, h_order2_norm2 (<<jitter-order-2>>):

h_dot_centerDD = np.dot(h,centerDD)
first = -2*h_dot_centerD + \
  2*jitter0*(centerD_norm2 - h_dot_centerDD) + \
  3*jitter0**2*centerD_dot_centerDD + \
  jitter0**3*centerDD_norm2
second = 2*(centerD_norm2 - h_dot_centerDD) + \
  6*jitter0*centerD_dot_centerDD + \
  3*jitter0**2*centerDD_norm2
jitter1 = jitter0 - first/second
h_order2_norm2 = sum((h-jitter1*centerD-jitter1**2/2*centerDD)**2)
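With \(D_2\) standing for centerDD and \(N\) for the squared norm of the second order corrected residual, first and second are \(N'(\delta_0)\) and \(N''(\delta_0)\), so jitter1 is a single Newton-Raphson step started from jitter0:

```latex
\begin{aligned}
N(\delta) &= \Bigl\lVert h - \delta D - \tfrac{\delta^2}{2} D_2 \Bigr\rVert^2,\\
N'(\delta) &= -2\langle h, D\rangle
  + 2\delta\bigl(\lVert D\rVert^2 - \langle h, D_2\rangle\bigr)
  + 3\delta^2\langle D, D_2\rangle + \delta^3\lVert D_2\rVert^2,\\
N''(\delta) &= 2\bigl(\lVert D\rVert^2 - \langle h, D_2\rangle\bigr)
  + 6\delta\langle D, D_2\rangle + 3\delta^2\lVert D_2\rVert^2,\\
\delta_1 &= \delta_0 - \frac{N'(\delta_0)}{N''(\delta_0)}.
\end{aligned}
```

A single Newton step is not guaranteed to lower \(N\), which is why the function compares h_order2_norm2 with h_order1_norm2 and falls back on jitter0 when the second order correction does not help.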

Now define the function:

def classify_and_align_evt(evt_pos, data, centers,
                           before=14, after=30):
    <<classify_and_align_evt-docstring>>
    <<classify_and_align_evt-centersM>>
    <<classify_and_align_evt-cluster_idx>>
    <<classify_and_align_evt-get-centers>>
    h = delta[cluster_idx,:]
    <<classify_and_align_evt-jitter-order-1>>
    if h_order0_norm2 > h_order1_norm2:
        <<classify_and_align_evt-jitter-order-2>>
        if h_order1_norm2 <= h_order2_norm2:
            jitter1 = jitter0
    else:
        jitter1 = 0
    if abs(round(jitter1)) > 0:
        evt_pos -= int(round(jitter1))
        evt = cut_sgl_evt(evt_pos,data=data,
                          before=before, after=after)
        h = evt - centers[good_cluster_name]["center"]\
          [good_cluster_idx]
        <<classify_and_align_evt-jitter-order-1>>  
        if h_order0_norm2 > h_order1_norm2:
            <<classify_and_align_evt-jitter-order-2>>
            if h_order1_norm2 <= h_order2_norm2:
                jitter1 = jitter0
        else:
            jitter1 = 0
    if sum(evt**2) > sum((h-jitter1*centerD-jitter1**2/2*centerDD)**2):
        return [cluster_names[cluster_idx], evt_pos, jitter1]
    else:
        return ['?',evt_pos, jitter1]

13.12. predict_data

We define function predict_data that creates an ideal data trace given events' positions, events' origins and a clusters' catalog. We start with the docstring:

"""Predicts ideal data given a list of centers' names, positions,
jitters and a dictionary of centers.

Parameters
----------
class_pos_jitter_list : a list of lists returned by
                        classify_and_align_evt.
centers_dictionary : a centers' dictionary returned by
                     mk_center_dictionary.
nb_channels : the number of recording channels.
data_length : the number of sampling points.

Returns
-------
A matrix of ideal (noise free) data with nb_channels rows and
data_length columns.
"""

And the function:

def predict_data(class_pos_jitter_list,
                 centers_dictionary,
                 nb_channels=4,
                 data_length=300000):
    <<predict_data-docstring>>
    ## Matrix that will contain the predicted data
    res = np.zeros((nb_channels,data_length))
    ## Go through every list element
    for class_pos_jitter in class_pos_jitter_list:
        cluster_name = class_pos_jitter[0]
        if cluster_name != '?':
            center = centers_dictionary[cluster_name]["center"]
            centerD = centers_dictionary[cluster_name]["centerD"]
            centerDD = centers_dictionary[cluster_name]["centerDD"]
            jitter = class_pos_jitter[2]
            pred = center + jitter*centerD + jitter**2/2*centerDD
            pred = pred.reshape((nb_channels,len(center)//nb_channels))
            idx = centers_dictionary[cluster_name]["center_idx"] + \
              class_pos_jitter[1]
            ## Keep only the samples falling within the recording:
            ## events close to the boundaries are truncated
            within = np.bitwise_and(0 <= idx, idx < data_length)
            kw = idx[within]
            res[:,kw] += pred[:,within]
    return res
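The within mask lets predict_data cope with events whose template would stick out of the recording: template samples falling outside [0, data_length) are dropped, instead of raising an IndexError (indices past the end) or silently wrapping around (negative indices). A toy illustration, with a made-up two-channel template and an event placed one sample away from the end of a 10-point recording:

```python
import numpy as np

nb_channels, data_length = 2, 10
center_idx = np.arange(-2, 3)                # before=2, after=2
# Toy template, already reshaped with one row per channel
pred = np.array([[1.0, 2.0, 3.0, 2.0, 1.0],
                 [0.5, 1.0, 1.5, 1.0, 0.5]])
position = 8
idx = center_idx + position                  # [6, 7, 8, 9, 10]
within = np.bitwise_and(0 <= idx, idx < data_length)
res = np.zeros((nb_channels, data_length))
res[:, idx[within]] += pred[:, within]       # index 10 is dropped
print(res[0].tolist())
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 2.0]
```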

Author: Christophe Pouzat
